#![allow(static_mut_refs)]
use std::{borrow::Cow, ops::DerefMut};
use crate::{
extension::resolver::{Subscription, SubscriptionCallback},
host_io::logger::HostLogger,
types::Configuration,
wit::{self, Error},
};
use super::extension::AnyExtension;
/// Factory stored by `register_extension` and consumed by `init` to build the extension
/// from the subgraph schemas and the extension configuration.
type InitFn =
    Box<dyn FnOnce(Vec<(String, wit::Schema)>, Configuration) -> Result<Box<dyn AnyExtension>, crate::types::Error>>;
// Global guest state. All of these are `static mut` and are read/written without any
// synchronization. NOTE(review): this relies on the guest being single-threaded and on these
// accessors never being re-entered while a reference obtained from them is live — confirm
// before introducing threads.
/// Extension factory: set by `register_extension`, taken (left `None`) by `init`.
static mut INIT_FN: Option<InitFn> = None;
/// The extension instance built by `init`.
static mut EXTENSION: Option<Box<dyn AnyExtension>> = None;
/// The single tracked subscription, pending or already created.
static mut SUBSCRIPTION: Option<SubscriptionState> = None;
/// Queue that `queue_event` pushes to; `None` means events are silently discarded.
static mut EVENT_QUEUE: Option<wit::EventQueue> = None;
/// Flag passed to `init` and exposed through `can_skip_sending_events`.
static mut CAN_SKIP_SENDING_EVENTS: bool = false;
/// Lifecycle of the single subscription tracked by this module.
enum SubscriptionState {
    /// Stored by `set_subscription_callback`; the callback has not been invoked yet.
    Uninitialized {
        /// Operation context, kept alive until the callback has run.
        ctx: Box<crate::types::AuthorizedOperationContext>,
        /// Opaque prepared bytes, kept alive until the callback has run.
        /// NOTE(review): presumably the prepared operation — confirm.
        prepared: Vec<u8>,
        /// Produces the subscription on first access through `subscription()`.
        callback: SubscriptionCallback<'static>,
    },
    /// The callback has run and produced this subscription.
    Initialized(Box<dyn Subscription>),
}
pub(super) fn init(
subgraph_schemas: Vec<(String, wit::Schema)>,
config: Configuration,
can_skip_sending_events: bool,
host_log_level: String,
) -> Result<(), Error> {
let mut builder = env_filter::Builder::new();
let host_log_level = parse_host_level(host_log_level);
builder.parse(&host_log_level);
let filter = builder.build();
let logger = HostLogger { filter };
log::set_boxed_logger(Box::new(logger)).expect("Failed to set logger");
log::set_max_level(log::LevelFilter::Trace);
unsafe {
let init = std::mem::take(&mut INIT_FN).expect("Resolver extension not initialized correctly.");
EXTENSION = Some(init(subgraph_schemas, config)?);
CAN_SKIP_SENDING_EVENTS = can_skip_sending_events;
}
Ok(())
}
pub(crate) fn with_event_queue<F, T>(event_queue: wit::EventQueue, f: F) -> T
where
F: FnOnce() -> T,
{
unsafe {
EVENT_QUEUE = Some(event_queue);
}
let res = f();
unsafe {
EVENT_QUEUE = None;
}
res
}
/// Returns the `can_skip_sending_events` flag recorded by `init` (`false` before that).
pub(crate) fn can_skip_sending_events() -> bool {
    // SAFETY: plain read of a `Copy` value through a raw pointer, so no reference to the static
    // is created. NOTE(review): assumes no concurrent writer (single-threaded guest).
    unsafe { *std::ptr::addr_of!(CAN_SKIP_SENDING_EVENTS) }
}
/// Pushes an event called `name` with payload `data` onto the current event queue.
///
/// If no queue is currently installed, the event is silently discarded.
pub(crate) fn queue_event(name: &str, data: &[u8]) {
    // SAFETY: shared borrow of the queue slot for the duration of the push.
    // NOTE(review): assumes nothing replaces `EVENT_QUEUE` while `push` runs.
    let Some(queue) = (unsafe { EVENT_QUEUE.as_ref() }) else {
        return;
    };
    queue.push(name, data);
}
/// Installs `event_queue` as the queue used by `queue_event`.
///
/// Any queue that was already installed is replaced and dropped.
pub(super) fn set_event_queue(event_queue: wit::EventQueue) {
    let installed = Some(event_queue);
    // SAFETY: NOTE(review): assumes no outstanding reference into `EVENT_QUEUE`
    // (single-threaded guest).
    unsafe { EVENT_QUEUE = installed }
}
/// Drops the current event queue; subsequent `queue_event` calls become no-ops.
pub(super) fn drop_event_queue() {
    // SAFETY: NOTE(review): assumes no outstanding reference into `EVENT_QUEUE`
    // (single-threaded guest).
    unsafe { EVENT_QUEUE = None }
}
/// Stores the factory that `init` will later call to build the extension.
///
/// Registering again before `init` replaces (and drops) the previously stored factory.
#[doc(hidden)]
pub(crate) fn register_extension(f: InitFn) {
    let factory = Some(f);
    // SAFETY: NOTE(review): assumes single-threaded access to `INIT_FN`.
    unsafe { INIT_FN = factory }
}
pub(super) fn extension() -> Result<&'static mut dyn AnyExtension, Error> {
unsafe {
EXTENSION.as_deref_mut().ok_or_else(|| Error {
message: "Extension was not initialized correctly.".to_string(),
extensions: Vec::new(),
})
}
}
/// Records a pending subscription whose `callback` runs lazily on the first `subscription()` call.
///
/// `ctx` and `prepared` are stored next to the callback so they stay alive until it has run.
/// Any previously stored subscription state is replaced and dropped.
pub(super) fn set_subscription_callback(
    ctx: Box<crate::types::AuthorizedOperationContext>,
    prepared: Vec<u8>,
    callback: SubscriptionCallback<'static>,
) {
    let pending = SubscriptionState::Uninitialized {
        ctx,
        prepared,
        callback,
    };
    // SAFETY: NOTE(review): assumes no outstanding reference obtained from `subscription()`
    // is still in use (single-threaded guest).
    unsafe { SUBSCRIPTION = Some(pending) }
}
/// Returns the active subscription, creating it from the stored callback on first access.
///
/// If the state is still pending (see `set_subscription_callback`), the callback is invoked.
/// Its result is stored as the initialized subscription, and the kept `prepared` bytes and
/// `ctx` are dropped. Later calls return the already-created subscription.
///
/// # Errors
///
/// Returns an error when there is no subscription state at all. Also propagates the
/// callback's error; in that case the pending state has already been taken out of the slot,
/// so later calls report "No active subscription.".
pub(super) fn subscription() -> Result<&'static mut dyn Subscription, Error> {
    // SAFETY: the returned `&'static mut` aliases `SUBSCRIPTION`. It dangles once
    // `drop_subscription` or `set_subscription_callback` replaces the slot.
    // NOTE(review): assumes a single-threaded guest with no re-entrant access.
    unsafe {
        // Move the state out so the `Uninitialized` fields can be consumed by value.
        let state = std::mem::take(&mut SUBSCRIPTION);
        match state {
            Some(SubscriptionState::Initialized(_)) => {
                SUBSCRIPTION = state; // Already created: put it back unchanged.
            }
            Some(SubscriptionState::Uninitialized {
                ctx,
                prepared,
                callback,
            }) => {
                // `prepared` and `ctx` are dropped only after the callback has run.
                // NOTE(review): presumably because the callback borrows from them — confirm.
                SUBSCRIPTION = Some(SubscriptionState::Initialized(callback()?));
                drop(prepared);
                drop(ctx);
            }
            None => {
                return Err(Error {
                    message: "No active subscription.".to_string(),
                    extensions: Vec::new(),
                });
            }
        }
        // Both non-returning arms above leave the slot as `Some(Initialized(_))`.
        let Some(SubscriptionState::Initialized(subscription)) = SUBSCRIPTION.as_mut() else {
            unreachable!();
        };
        Ok(subscription.deref_mut())
    }
}
/// Discards the tracked subscription state, whether pending or already created.
///
/// Any reference previously returned by `subscription()` must not be used afterwards.
pub(super) fn drop_subscription() {
    // SAFETY: NOTE(review): assumes no reference into `SUBSCRIPTION` is still in use
    // (single-threaded guest).
    unsafe { SUBSCRIPTION = None }
}
/// Rewrites the host's log filter spec so extension-specific directives take effect.
///
/// `host_log_level` is a comma-separated list of `env_filter` directives. Each part is
/// trimmed, and then:
/// - every `extension=<spec>` directive is unwrapped into a bare `<spec>`;
/// - if at least one such directive exists, bare global levels (`off`, `error`, `warn`,
///   `info`, `debug`, `trace`) are dropped so they cannot override the extension level.
///   Levels are matched case-insensitively, like `log`/`env_filter` parse them;
/// - all other parts (module directives such as `my_module=info`, empty parts) are kept in
///   their original order.
fn parse_host_level(host_log_level: String) -> String {
    // Level names accepted as a global directive; `off` is included because it would
    // otherwise silence the extension entirely.
    const GLOBAL_LEVELS: [&str; 6] = ["off", "error", "warn", "info", "debug", "trace"];

    let is_global_level =
        |part: &str| GLOBAL_LEVELS.iter().any(|level| part.eq_ignore_ascii_case(level));

    let parts: Vec<&str> = host_log_level.split(',').map(str::trim).collect();
    let has_extension_directives = parts.iter().any(|part| part.starts_with("extension="));

    parts
        .into_iter()
        .filter_map(|part| match part.strip_prefix("extension=") {
            // Borrow straight from `host_log_level`; no need to allocate.
            Some(level) => Some(Cow::Borrowed(level)),
            None if has_extension_directives && is_global_level(part) => None,
            None => Some(Cow::Borrowed(part)),
        })
        .collect::<Vec<_>>()
        .join(",")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every `(input, expected)` pair through `parse_host_level`, naming the failing input.
    fn assert_cases(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(
                parse_host_level(input.to_owned()),
                expected,
                "unexpected result for input {input:?}"
            );
        }
    }

    #[test]
    fn test_parse_host_level_with_extension_directive() {
        // A lone `extension=<level>` directive is unwrapped to the bare level.
        assert_cases(&[
            ("extension=debug", "debug"),
            ("extension=info", "info"),
            ("extension=trace", "trace"),
            ("extension=warn", "warn"),
            ("extension=error", "error"),
        ]);
    }

    #[test]
    fn test_parse_host_level_with_extension_and_standalone_levels() {
        // Bare global levels are dropped whenever an `extension=` directive is present.
        assert_cases(&[
            ("extension=debug,info", "debug"),
            ("info,extension=debug", "debug"),
            ("trace,extension=warn,error", "warn"),
            ("debug,info,extension=error,warn,trace", "error"),
        ]);
    }

    #[test]
    fn test_parse_host_level_with_extension_and_module_directives() {
        // Module directives survive alongside the unwrapped extension level.
        assert_cases(&[
            ("extension=debug,my_module=info", "debug,my_module=info"),
            ("my_module=info,extension=debug", "my_module=info,debug"),
            (
                "extension=warn,crate1=debug,crate2=info",
                "warn,crate1=debug,crate2=info",
            ),
        ]);
    }

    #[test]
    fn test_parse_host_level_without_extension_directive() {
        // Without an `extension=` directive the spec passes through unchanged.
        assert_cases(&[
            ("debug", "debug"),
            ("info,warn", "info,warn"),
            ("debug,my_module=info", "debug,my_module=info"),
            (
                "trace,crate1=debug,crate2=info,error",
                "trace,crate1=debug,crate2=info,error",
            ),
        ]);
    }

    #[test]
    fn test_parse_host_level_with_whitespace() {
        // Whitespace around each comma-separated part is trimmed.
        assert_cases(&[
            ("extension=debug, info", "debug"),
            (" extension=debug , info ", "debug"),
            ("extension=debug, my_module=info", "debug,my_module=info"),
            (" debug , my_module=info ", "debug,my_module=info"),
        ]);
    }

    #[test]
    fn test_parse_host_level_edge_cases() {
        assert_cases(&[
            ("", ""),
            (",,,", ",,,"),
            ("extension=debug,extension=info", "debug,info"),
            ("extension=", ""),
            (
                "extension=debug,my-module=info,my::module=warn",
                "debug,my-module=info,my::module=warn",
            ),
        ]);
    }

    #[test]
    fn test_parse_host_level_preserves_order() {
        // The unwrapped extension level keeps the position of its `extension=` directive.
        assert_cases(&[
            ("a=1,extension=debug,b=2,c=3", "a=1,debug,b=2,c=3"),
            ("x=info,y=warn,extension=error,z=trace", "x=info,y=warn,error,z=trace"),
        ]);
    }
}