use std::sync::Arc;
use async_trait::async_trait;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::stream::Stream;
use onwards::client::HttpClient;
use onwards::traits::RequestContext;
use onwards::{EventSink, EventSinkError, LoopConfig, LoopError, LoopEvent, MultiStepStore, UpstreamTarget};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::responses::store::FusilladeResponseStore;
use crate::tool_executor::HttpToolExecutor;
/// Capacity of the bounded `mpsc` channel that `run_inline_streaming` creates
/// to carry SSE events from the spawned loop task to the response stream.
const SSE_CHANNEL_BUFFER: usize = 256;
/// [`EventSink`] implementation that forwards loop events into an `mpsc`
/// channel as axum SSE [`Event`]s.
pub struct SseEventSink {
/// Sending half of the event channel; each emitted event is pushed here
/// wrapped in `Ok`.
tx: mpsc::Sender<Result<Event, axum::Error>>,
}
impl SseEventSink {
pub fn new(tx: mpsc::Sender<Result<Event, axum::Error>>) -> Self {
Self { tx }
}
}
#[async_trait]
impl EventSink for SseEventSink {
    /// Converts a loop event into an SSE frame and queues it on the channel.
    ///
    /// The frame's `id` is the event's sequence number, its `event` name is
    /// the event kind, and its `data` is the event data serialized as JSON.
    ///
    /// # Errors
    ///
    /// Returns an [`EventSinkError`] when the data cannot be serialized or
    /// when sending on the channel fails.
    async fn emit(&self, event: LoopEvent) -> Result<(), EventSinkError> {
        let payload = match serde_json::to_string(&event.data) {
            Ok(json) => json,
            Err(e) => return Err(EventSinkError(format!("serialize SSE data: {e}"))),
        };

        // Builder calls stay in the original id -> event -> data order.
        let frame = Event::default()
            .id(event.sequence.to_string())
            .event(event.kind.as_str())
            .data(payload);

        match self.tx.send(Ok(frame)).await {
            Ok(()) => Ok(()),
            Err(e) => Err(EventSinkError(format!("SSE channel closed: {e}"))),
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn run_inline_streaming<P>(
response_store: Arc<FusilladeResponseStore<P>>,
tool_executor: Arc<HttpToolExecutor>,
tool_resolved: Arc<crate::tool_executor::ResolvedToolSet>,
http_client: Arc<dyn HttpClient + Send + Sync>,
upstream: UpstreamTarget,
loop_config: LoopConfig,
request_id: String,
model_alias: String,
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>>
where
P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
let (tx, rx) = mpsc::channel::<Result<Event, axum::Error>>(SSE_CHANNEL_BUFFER);
tokio::spawn(async move {
let sink = SseEventSink::new(tx.clone());
let tool_ctx = RequestContext::new()
.with_model(model_alias)
.with_extension(crate::tool_executor::ResolvedTools(tool_resolved));
let result = onwards::run_response_loop(
&*response_store,
&*tool_executor,
&tool_ctx,
&upstream,
http_client,
Some(&sink),
&request_id,
None,
loop_config,
0,
)
.await;
match &result {
Ok(_) => {
if let Err(e) = persist_terminal_completed(&response_store, &request_id).await {
tracing::warn!(error = %e, "Failed to persist warm-path terminal state");
let _ = tx
.send(Ok(Event::default()
.event("response.failed")
.data(format!("{{\"type\":\"persist_failed\",\"message\":\"{e}\"}}"))))
.await;
}
}
Err(LoopError::Failed(payload)) => {
if let Err(e) = persist_terminal_failed(&response_store, &request_id, payload).await {
tracing::warn!(error = %e, "Failed to persist warm-path failure state");
}
}
Err(other) => {
let payload = serde_json::json!({
"type": "loop_error",
"message": other.to_string(),
});
if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
tracing::warn!(error = %e, "Failed to persist warm-path error state");
}
}
}
response_store.unregister_pending(&request_id);
drop(tx);
});
let stream = ReceiverStream::new(rx);
Sse::new(stream).keep_alive(KeepAlive::default())
}
/// Assembles the stored response for `request_id` and finalizes the head
/// request with it, passing `200` as the status.
///
/// # Errors
///
/// Returns a message prefixed with `assemble: ` when assembly fails, or with
/// `finalize head: ` when finalization fails.
async fn persist_terminal_completed<P>(
    response_store: &FusilladeResponseStore<P>,
    request_id: &str,
) -> Result<(), String>
where
    P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
    let body = match response_store.assemble_response(request_id).await {
        Ok(body) => body,
        Err(e) => return Err(format!("assemble: {e}")),
    };

    match response_store.finalize_head_request(request_id, 200, body).await {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("finalize head: {e}")),
    }
}
/// Runs the response loop for `request_id` to completion without an event
/// sink and returns the final result to the caller.
///
/// Returns `Ok` with the assembled response when the loop succeeds and the
/// response can be assembled. Returns `Err` with a JSON payload otherwise:
/// the payload carried by `LoopError::Failed`, a `"type": "loop_error"`
/// object for any other loop error, or a `"type": "assemble_failed"` object
/// when assembling the successful response fails.
///
/// The terminal state is persisted before returning; persistence failures are
/// logged at `warn` level and do not change the returned value. The pending
/// request is unregistered on every path that reaches the end of the function.
#[allow(clippy::too_many_arguments)]
pub async fn run_inline_blocking<P>(
    response_store: Arc<FusilladeResponseStore<P>>,
    tool_executor: Arc<HttpToolExecutor>,
    tool_resolved: Arc<crate::tool_executor::ResolvedToolSet>,
    http_client: Arc<dyn HttpClient + Send + Sync>,
    upstream: UpstreamTarget,
    loop_config: LoopConfig,
    request_id: String,
    model_alias: String,
) -> Result<Value, Value>
where
    P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
    let resolved_tools = crate::tool_executor::ResolvedTools(tool_resolved);
    let tool_ctx = RequestContext::new()
        .with_model(model_alias)
        .with_extension(resolved_tools);

    let loop_result = onwards::run_response_loop(
        &*response_store,
        &*tool_executor,
        &tool_ctx,
        &upstream,
        http_client,
        None,
        &request_id,
        None,
        loop_config,
        0,
    )
    .await;

    let outcome = match loop_result {
        // The loop reported a failure and supplied its own payload.
        Err(LoopError::Failed(payload)) => {
            if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
                tracing::warn!(error = %e, "Failed to persist warm-path-blocking failure");
            }
            Err(payload)
        }
        // Any other loop error is wrapped in a generic payload.
        Err(other) => {
            let payload = serde_json::json!({
                "type": "loop_error",
                "message": other.to_string(),
            });
            if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
                tracing::warn!(error = %e, "Failed to persist warm-path-blocking error");
            }
            Err(payload)
        }
        Ok(_) => {
            if let Err(e) = persist_terminal_completed(&response_store, &request_id).await {
                tracing::warn!(error = %e, "Failed to persist warm-path-blocking terminal state");
            }
            // NOTE(review): the response is assembled a second time here even
            // though `persist_terminal_completed` already assembled it; confirm
            // whether the post-finalize assembly is intentionally different
            // before collapsing the two calls.
            match response_store.assemble_response(&request_id).await {
                Ok(assembled) => Ok(assembled),
                Err(e) => Err(serde_json::json!({"type": "assemble_failed", "message": e.to_string()})),
            }
        }
    };

    response_store.unregister_pending(&request_id);
    outcome
}
async fn persist_terminal_failed<P>(response_store: &FusilladeResponseStore<P>, request_id: &str, error: &Value) -> Result<(), String>
where
P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
response_store
.finalize_head_request(request_id, 500, error.clone())
.await
.map_err(|e| format!("finalize head: {e}"))
}