use std::sync::Arc;
use async_trait::async_trait;
use fusillade::request::{Canceled, Claimed, Completed, Failed, Request, RequestCompletionResult};
use fusillade::{
CancellationFuture, DefaultRequestProcessor, FailureReason, PoolProvider as FusilladePool, RequestProcessor, ShouldRetry, Storage,
};
use onwards::client::HttpClient;
use onwards::traits::{RequestContext, ToolExecutor};
use onwards::{LoopConfig, LoopError, MultiStepStore, UpstreamTarget};
use crate::responses::store::{FusilladeResponseStore, PendingResponseInput};
use crate::tool_executor::ResolvedTools;
/// Fusillade request processor that drives `/v1/responses` requests through the onwards
/// multi-step response loop and hands every other path to [`DefaultRequestProcessor`].
pub struct DwctlRequestProcessor<
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
> {
    /// Store that holds the pending input for a request and assembles its final response.
    pub response_store: Arc<FusilladeResponseStore<P>>,
    /// Executor handed to the response loop for tool calls.
    pub tool_executor: Arc<T>,
    /// HTTP client the response loop uses to reach the upstream.
    pub http_client: Arc<dyn HttpClient + Send + Sync>,
    /// Configuration forwarded unchanged to the response loop.
    pub loop_config: LoopConfig,
    /// Optional per-request tool resolver; when `None`, no tool resolution is attempted.
    pub tool_resolver: Option<Arc<dyn DaemonToolResolver>>,
    /// Fallback processor used for every path other than `/v1/responses`.
    pub default: DefaultRequestProcessor,
}
/// Resolves the tool set for a daemon-driven `/v1/responses` request.
///
/// `Ok(None)` means no tools apply. The processor treats both `Ok(None)` and `Err(_)` as
/// "run the loop without resolved tools"; an error is only logged.
#[async_trait]
pub trait DaemonToolResolver: Send + Sync {
    /// Looks up the tools available to `api_key` when calling `model_alias`.
    async fn resolve(
        &self,
        api_key: &str,
        model_alias: &str,
    ) -> anyhow::Result<Option<crate::tool_executor::ResolvedToolSet>>;
}
/// Database-backed [`DaemonToolResolver`].
///
/// Delegates each lookup to `crate::tool_injection::resolve_tools_for_request` using `pool`.
/// Cloning is cheap: it clones the pool handle.
#[derive(Clone, Debug)]
pub struct DbToolResolver {
    /// Postgres pool passed through to the tool-injection query.
    pub pool: sqlx::PgPool,
}
#[async_trait]
impl DaemonToolResolver for DbToolResolver {
    /// Runs the shared tool-injection lookup, always scoped to `model_alias`.
    async fn resolve(
        &self,
        api_key: &str,
        model_alias: &str,
    ) -> anyhow::Result<Option<crate::tool_executor::ResolvedToolSet>> {
        let model_filter = Some(model_alias);
        crate::tool_injection::resolve_tools_for_request(&self.pool, api_key, model_filter).await
    }
}
impl<P, T> DwctlRequestProcessor<P, T>
where
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Builds a processor with no tool resolver.
    ///
    /// The stock [`DefaultRequestProcessor`] handles non-`/v1/responses` paths.
    pub fn new(
        response_store: Arc<FusilladeResponseStore<P>>,
        tool_executor: Arc<T>,
        http_client: Arc<dyn HttpClient + Send + Sync>,
        loop_config: LoopConfig,
    ) -> Self {
        let tool_resolver = None;
        let default = DefaultRequestProcessor;
        Self {
            response_store,
            tool_executor,
            http_client,
            loop_config,
            tool_resolver,
            default,
        }
    }

    /// Builder-style setter that installs `resolver` for per-request tool resolution.
    ///
    /// Replaces any resolver configured earlier.
    pub fn with_tool_resolver(self, resolver: Arc<dyn DaemonToolResolver>) -> Self {
        Self {
            tool_resolver: Some(resolver),
            ..self
        }
    }
}
#[async_trait]
impl<S, H, P, T> RequestProcessor<S, H> for DwctlRequestProcessor<P, T>
where
    S: Storage + Sync,
    H: fusillade::HttpClient + 'static,
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Processes one claimed request.
    ///
    /// Any path other than `/v1/responses` is delegated unchanged to
    /// [`DefaultRequestProcessor`].
    ///
    /// A `/v1/responses` request goes through these steps:
    /// - Tools are resolved on a best-effort basis.
    /// - The request's input is registered as pending in the response store.
    /// - `onwards::run_response_loop` runs against `{endpoint}/v1/chat/completions`, using
    ///   `self.http_client` and `self.tool_executor`.
    /// - On success, the assembled response is persisted as `Completed` with status 200.
    /// - On any loop error, the request is persisted as `Failed` with
    ///   `NonRetriableHttpStatus { status: 500, .. }`.
    ///
    /// # Errors
    ///
    /// Returns an error without persisting a terminal state in these cases:
    /// - the pending input cannot be registered;
    /// - the response cannot be assembled or serialized after a successful loop;
    /// - persisting the terminal state fails.
    async fn process(
        &self,
        request: Request<Claimed>,
        http: H,
        storage: &S,
        should_retry: ShouldRetry,
        cancellation: CancellationFuture,
    ) -> fusillade::Result<RequestCompletionResult> {
        if request.data.path != "/v1/responses" {
            return self
                .default
                .process(request, http, storage, should_retry, cancellation)
                .await;
        }

        // NOTE(review): `http` and `should_retry` are unused on this path, because the loop
        // uses `self.http_client`. `cancellation` is held for the whole call but never
        // polled, so cancelling the request does not interrupt the loop.
        // TODO: race the loop against `cancellation`.
        let _cancellation = cancellation;

        // Capture this before any work so `Completed::started_at` records when processing
        // began. Previously it was stamped after the loop finished, which made it equal to
        // `completed_at`.
        let started_at = chrono::Utc::now();
        let request_id = request.data.id.0.to_string();

        // An empty string in the request data means "not set".
        let api_key = (!request.data.api_key.is_empty()).then(|| request.data.api_key.clone());
        let created_by =
            (!request.data.created_by.is_empty()).then(|| request.data.created_by.clone());

        let upstream = UpstreamTarget {
            url: format!(
                "{}/v1/chat/completions",
                request.data.endpoint.trim_end_matches('/')
            ),
            api_key: api_key.clone(),
        };

        // Tool resolution is best-effort. A missing resolver, `Ok(None)`, or an error all
        // run the loop with no resolved tools.
        let mut tool_ctx = RequestContext::new().with_model(request.data.model.clone());
        let mut resolved_tool_names = std::collections::HashSet::new();
        if let Some(resolver) = &self.tool_resolver {
            match resolver.resolve(&request.data.api_key, &request.data.model).await {
                Ok(Some(resolved)) => {
                    resolved_tool_names = resolved.tools.keys().cloned().collect();
                    tool_ctx = tool_ctx.with_extension(ResolvedTools(Arc::new(resolved)));
                }
                Ok(None) => {
                    tracing::debug!(
                        request_id = %request.data.id,
                        model = %request.data.model,
                        "no tools resolved for daemon-driven /v1/responses request"
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        request_id = %request.data.id,
                        "tool resolution failed for daemon path; running loop with no tools"
                    );
                }
            }
        }

        let pending = PendingResponseInput {
            body: request.data.body.clone(),
            api_key,
            created_by,
            base_url: request.data.endpoint.clone(),
            resolved_tool_names,
        };
        self.response_store
            .register_pending_with_id(request.data.id.0, pending)
            .map_err(|e| {
                fusillade::FusilladeError::Other(anyhow::anyhow!(
                    "register pending input for daemon-driven /v1/responses: {e}"
                ))
            })?;

        // Unregister the pending input on every exit from here on, including early `?`
        // returns below.
        let cleanup_store = Arc::clone(&self.response_store);
        let cleanup_id = request_id.clone();
        let _pending_guard = scopeguard::guard((), move |_| {
            cleanup_store.unregister_pending(&cleanup_id);
        });

        let result = onwards::run_response_loop(
            &*self.response_store,
            &*self.tool_executor,
            &tool_ctx,
            &upstream,
            self.http_client.clone(),
            None,
            &request_id,
            None,
            self.loop_config,
            0,
        )
        .await;

        // On success, persist `Completed` and return early. Every failure shape is reduced
        // to a body string that feeds the single `Failed` path below.
        let failure_body = match result {
            Ok(_final_payload) => {
                let assembled = self
                    .response_store
                    .assemble_response(&request_id)
                    .await
                    .map_err(|e| {
                        fusillade::FusilladeError::Other(anyhow::anyhow!(
                            "assemble_response after loop: {e}"
                        ))
                    })?;
                let body = serde_json::to_string(&assembled).map_err(|e| {
                    fusillade::FusilladeError::Other(anyhow::anyhow!(
                        "serialize assembled response: {e}"
                    ))
                })?;
                let completed = Request {
                    data: request.data.clone(),
                    state: Completed {
                        response_status: 200,
                        response_body: body,
                        claimed_at: request.state.claimed_at,
                        started_at,
                        completed_at: chrono::Utc::now(),
                        routed_model: request.data.model.clone(),
                    },
                };
                storage.persist(&completed).await?;
                return Ok(RequestCompletionResult::Completed(completed));
            }
            // Previously a serialization failure produced an empty body. Keep a diagnostic
            // instead.
            Err(LoopError::Failed(payload)) => serde_json::to_string(&payload)
                .unwrap_or_else(|e| format!("failed to serialize loop failure payload: {e}")),
            Err(other) => format!("multi-step loop error: {other}"),
        };

        let failed = Request {
            data: request.data.clone(),
            state: Failed {
                reason: FailureReason::NonRetriableHttpStatus {
                    status: 500,
                    body: failure_body,
                },
                failed_at: chrono::Utc::now(),
                retry_attempt: request.state.retry_attempt,
                batch_expires_at: request.state.batch_expires_at,
                routed_model: request.data.model.clone(),
            },
        };
        storage.persist(&failed).await?;
        Ok(RequestCompletionResult::Failed(failed))
    }
}
// `Canceled` is imported with the other request states, but this no-op is its only use
// in the file. It keeps that import referenced; `dead_code` is allowed because nothing
// calls it.
#[allow(dead_code)]
fn _smoke(_state: Canceled) {}