use std::sync::Arc;
use async_trait::async_trait;
use fusillade::request::{Canceled, Claimed, Request, RequestCompletionResult};
use fusillade::{
CancellationFuture, DefaultRequestProcessor, PoolProvider as FusilladePool, RequestProcessor, ReqwestHttpClient, ShouldRetry, Storage,
};
use onwards::LoopConfig;
use onwards::traits::ToolExecutor;
use crate::responses::loop_http_client::ResponseLoopHttpClient;
use crate::responses::store::FusilladeResponseStore;
/// Request processor used by dwctl on top of fusillade.
///
/// Its `RequestProcessor` implementation optionally signs image tokens found in
/// the request body, sends `/v1/responses` requests that carry a non-empty
/// `tools` array through a `ResponseLoopHttpClient`, and hands every other
/// request to the wrapped `DefaultRequestProcessor`.
pub struct DwctlRequestProcessor<P, T>
where
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Response store handed to the `ResponseLoopHttpClient` built for tool requests.
    pub response_store: Arc<FusilladeResponseStore<P>>,
    /// Tool executor handed to the `ResponseLoopHttpClient` built for tool requests.
    pub tool_executor: Arc<T>,
    /// HTTP client given to the loop client as its `inner_http`.
    pub http_client: Arc<ReqwestHttpClient>,
    /// Loop configuration copied into each loop client.
    pub loop_config: LoopConfig,
    /// Optional resolver forwarded to the loop client; `None` unless
    /// `with_tool_resolver` was called.
    pub tool_resolver: Option<Arc<dyn DaemonToolResolver>>,
    /// Optional image normalizer; when present, image tokens in the request
    /// body are replaced by signed URLs before dispatch.
    pub image_normalizer: Option<Arc<dyn crate::image_normalizer::ImageNormalizer>>,
    /// Duration passed to the normalizer's `sign` call for every token.
    /// `new` sets it to 1800 seconds.
    pub dispatch_ttl: std::time::Duration,
    /// Fallback processor for requests that do not need the tool loop.
    pub default: DefaultRequestProcessor,
}
#[async_trait]
pub trait DaemonToolResolver: Send + Sync {
async fn resolve(&self, api_key: &str, model_alias: &str) -> Result<Option<crate::tool_executor::ResolvedToolSet>, anyhow::Error>;
}
/// `DaemonToolResolver` implementation that resolves tools through a Postgres
/// connection pool.
pub struct DbToolResolver {
    /// Pool passed to `crate::tool_injection::resolve_tools_for_request`.
    pub pool: sqlx::PgPool,
}
#[async_trait]
impl DaemonToolResolver for DbToolResolver {
async fn resolve(&self, api_key: &str, model_alias: &str) -> Result<Option<crate::tool_executor::ResolvedToolSet>, anyhow::Error> {
crate::tool_injection::resolve_tools_for_request(&self.pool, api_key, Some(model_alias)).await
}
}
impl<P, T> DwctlRequestProcessor<P, T>
where
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Builds a processor with no tool resolver, no image normalizer, and a
    /// dispatch TTL of 1800 seconds.
    pub fn new(
        response_store: Arc<FusilladeResponseStore<P>>,
        tool_executor: Arc<T>,
        http_client: Arc<ReqwestHttpClient>,
        loop_config: LoopConfig,
    ) -> Self {
        // TTL used until `with_image_normalizer` supplies its own value.
        const DEFAULT_DISPATCH_TTL: std::time::Duration = std::time::Duration::from_secs(1800);

        Self {
            response_store,
            tool_executor,
            http_client,
            loop_config,
            tool_resolver: None,
            image_normalizer: None,
            dispatch_ttl: DEFAULT_DISPATCH_TTL,
            default: DefaultRequestProcessor,
        }
    }

    /// Enables image-token signing: stores `normalizer` and replaces the
    /// dispatch TTL with `ttl`. Consumes and returns the processor so calls
    /// can be chained.
    pub fn with_image_normalizer(
        mut self,
        normalizer: Arc<dyn crate::image_normalizer::ImageNormalizer>,
        ttl: std::time::Duration,
    ) -> Self {
        (self.image_normalizer, self.dispatch_ttl) = (Some(normalizer), ttl);
        self
    }

    /// Stores `resolver` so it is forwarded to the loop client built for tool
    /// requests. Consumes and returns the processor so calls can be chained.
    pub fn with_tool_resolver(self, resolver: Arc<dyn DaemonToolResolver>) -> Self {
        let mut configured = self;
        configured.tool_resolver = Some(resolver);
        configured
    }
}
#[async_trait]
impl<S, H, P, T> RequestProcessor<S, H> for DwctlRequestProcessor<P, T>
where
S: Storage + Sync,
H: fusillade::HttpClient + 'static,
P: FusilladePool + Clone + Send + Sync + 'static,
T: ToolExecutor + 'static,
{
async fn process(
&self,
mut request: Request<Claimed>,
http: H,
storage: &S,
should_retry: ShouldRetry,
cancellation: CancellationFuture,
) -> fusillade::Result<RequestCompletionResult> {
if let Some(normalizer) = self.image_normalizer.clone() {
let ttl = self.dispatch_ttl;
let mut body_value: serde_json::Value = serde_json::from_str(&request.data.body).map_err(|e| {
fusillade::FusilladeError::ValidationError(format!(
"JIT image signing: request body is not valid JSON ({e}); refusing to dispatch with unresolved tokens"
))
})?;
let result = crate::image_normalizer::walker::substitute_with(
&mut body_value,
crate::image_normalizer::Mode::TokensOnly,
|maybe_token| {
let normalizer = Arc::clone(&normalizer);
async move {
let token: crate::image_normalizer::ImageToken = maybe_token
.parse()
.map_err(|e: crate::image_normalizer::TokenParseError| format!("invalid dw-img token: {e}"))?;
let signed = normalizer.sign(token, ttl).await.map_err(|e| format!("sign failed: {e}"))?;
Ok::<String, String>(signed.url)
}
},
)
.await;
match result {
Ok(count) if count > 0 => match serde_json::to_string(&body_value) {
Ok(new_body) => request.data.body = new_body,
Err(e) => {
return Err(fusillade::FusilladeError::Other(anyhow::anyhow!(
"re-serialise body after JIT signing: {e}"
)));
}
},
Ok(_) => {} Err(e) => {
return Err(fusillade::FusilladeError::Other(anyhow::anyhow!(
"JIT image-URL signing failed: {e}"
)));
}
}
}
if request.data.path != "/v1/responses" {
return self.default.process(request, http, storage, should_retry, cancellation).await;
}
let has_tools = serde_json::from_str::<serde_json::Value>(&request.data.body)
.ok()
.as_ref()
.and_then(|v| v.get("tools"))
.and_then(|v| v.as_array())
.is_some_and(|a| !a.is_empty());
if !has_tools {
return self.default.process(request, http, storage, should_retry, cancellation).await;
}
let loop_client = ResponseLoopHttpClient {
response_store: self.response_store.clone(),
tool_executor: self.tool_executor.clone(),
inner_http: self.http_client.clone(),
tool_resolver: self.tool_resolver.clone(),
loop_config: self.loop_config,
};
let processing = request.process(loop_client, storage).await?;
processing.complete(storage, |resp| should_retry(resp), cancellation).await
}
}
// Never called from the code visible here.
// NOTE(review): presumably exists only so the `Canceled` import stays
// referenced — confirm before removing.
#[allow(dead_code)]
fn _smoke(_canceled: Canceled) {}