use std::sync::Arc;
use async_trait::async_trait;
use fusillade::request::{Canceled, Claimed, Request, RequestCompletionResult};
use fusillade::{CancellationFuture, DefaultRequestProcessor, PoolProvider as FusilladePool, RequestProcessor, ShouldRetry, Storage};
use onwards::LoopConfig;
use onwards::client::HttpClient;
use onwards::traits::ToolExecutor;
use crate::responses::loop_http_client::ResponseLoopHttpClient;
use crate::responses::store::FusilladeResponseStore;
/// Fusillade request processor for dwctl.
///
/// Claimed requests whose path is exactly `/v1/responses` are executed through
/// a per-request [`ResponseLoopHttpClient`] assembled from the fields below.
/// All other paths are delegated to fusillade's [`DefaultRequestProcessor`].
pub struct DwctlRequestProcessor<P, T>
where
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Response store shared with every loop client this processor builds.
    pub response_store: Arc<FusilladeResponseStore<P>>,
    /// Tool executor shared with every loop client this processor builds.
    pub tool_executor: Arc<T>,
    /// HTTP client handed to each loop client as its `inner_http`.
    pub http_client: Arc<dyn HttpClient + Send + Sync>,
    /// Loop settings copied into each loop client.
    pub loop_config: LoopConfig,
    /// Optional tool resolver; `None` unless attached via `with_tool_resolver`.
    pub tool_resolver: Option<Arc<dyn DaemonToolResolver>>,
    /// Fallback processor used for every path other than `/v1/responses`.
    pub default: DefaultRequestProcessor,
}
/// Looks up the tool set to attach to a request handled by the daemon.
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc<dyn DaemonToolResolver>`.
#[async_trait]
pub trait DaemonToolResolver: Send + Sync {
    /// Resolves the tools for the given API key and model alias.
    ///
    /// NOTE(review): `Ok(None)` presumably means "no tools configured for this
    /// key/model". Confirm against the consumers of `ResolvedToolSet`.
    ///
    /// # Errors
    /// Returns whatever error the underlying lookup produces.
    async fn resolve(
        &self,
        api_key: &str,
        model_alias: &str,
    ) -> anyhow::Result<Option<crate::tool_executor::ResolvedToolSet>>;
}
/// [`DaemonToolResolver`] that looks tools up through a Postgres pool.
///
/// `PgPool` is a cheaply clonable handle, so deriving `Clone` lets callers
/// duplicate the resolver without re-creating the pool. `Debug` is derived so
/// the public type can appear in diagnostics.
#[derive(Clone, Debug)]
pub struct DbToolResolver {
    /// Pool passed to `crate::tool_injection::resolve_tools_for_request`.
    pub pool: sqlx::PgPool,
}
#[async_trait]
impl DaemonToolResolver for DbToolResolver {
    /// Delegates to `crate::tool_injection::resolve_tools_for_request`.
    ///
    /// This resolver always supplies the model alias to that helper.
    async fn resolve(
        &self,
        api_key: &str,
        model_alias: &str,
    ) -> anyhow::Result<Option<crate::tool_executor::ResolvedToolSet>> {
        let pool = &self.pool;
        let alias = Some(model_alias);
        crate::tool_injection::resolve_tools_for_request(pool, api_key, alias).await
    }
}
impl<P, T> DwctlRequestProcessor<P, T>
where
    P: FusilladePool + Clone + Send + Sync + 'static,
    T: ToolExecutor + 'static,
{
    /// Creates a processor with no tool resolver attached.
    ///
    /// Chain [`Self::with_tool_resolver`] to attach one.
    #[must_use]
    pub fn new(
        response_store: Arc<FusilladeResponseStore<P>>,
        tool_executor: Arc<T>,
        http_client: Arc<dyn HttpClient + Send + Sync>,
        loop_config: LoopConfig,
    ) -> Self {
        Self {
            response_store,
            tool_executor,
            http_client,
            loop_config,
            tool_resolver: None,
            default: DefaultRequestProcessor,
        }
    }

    /// Attaches `resolver`, replacing any previously attached resolver.
    ///
    /// This method consumes `self` and returns the updated processor, so
    /// ignoring the result silently drops the whole processor. It is therefore
    /// marked `#[must_use]`.
    #[must_use]
    pub fn with_tool_resolver(mut self, resolver: Arc<dyn DaemonToolResolver>) -> Self {
        self.tool_resolver = Some(resolver);
        self
    }
}
#[async_trait]
impl<S, H, P, T> RequestProcessor<S, H> for DwctlRequestProcessor<P, T>
where
S: Storage + Sync,
H: fusillade::HttpClient + 'static,
P: FusilladePool + Clone + Send + Sync + 'static,
T: ToolExecutor + 'static,
{
async fn process(
&self,
request: Request<Claimed>,
http: H,
storage: &S,
should_retry: ShouldRetry,
cancellation: CancellationFuture,
) -> fusillade::Result<RequestCompletionResult> {
if request.data.path != "/v1/responses" {
return self.default.process(request, http, storage, should_retry, cancellation).await;
}
let loop_client = ResponseLoopHttpClient {
response_store: self.response_store.clone(),
tool_executor: self.tool_executor.clone(),
inner_http: self.http_client.clone(),
tool_resolver: self.tool_resolver.clone(),
loop_config: self.loop_config,
};
let processing = request.process(loop_client, storage).await?;
processing.complete(storage, |resp| should_retry(resp), cancellation).await
}
}
// NOTE(review): within the visible source, `Canceled` (from the
// `fusillade::request` import) is referenced only here. This never-called
// function appears to exist just to keep that import in use. If nothing else
// needs `Canceled`, remove both the import and this function.
// The `allow` is belt-and-braces: rustc already skips dead-code warnings for
// items whose names start with `_`.
#[allow(dead_code)]
fn _smoke(_c: Canceled) {}