use std::sync::Arc;
use async_trait::async_trait;
use axum::response::sse::{Event, KeepAlive, Sse};
use fusillade::ReqwestHttpClient;
use futures::stream::Stream;
use onwards::traits::RequestContext;
use onwards::{EventSink, EventSinkError, LoopConfig, LoopError, LoopEvent, MultiStepStore, UpstreamTarget};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::inference::store::FusilladeResponseStore;
use crate::inference::tools::HttpToolExecutor;
/// Capacity of the mpsc channel between the warm-path response loop task and the SSE body
/// stream; once this many events are queued, `SseEventSink::emit` waits for the client to drain.
const SSE_CHANNEL_BUFFER: usize = 256;
/// [`EventSink`] implementation that forwards loop events as axum SSE [`Event`]s over an mpsc
/// channel whose receiving end backs an [`Sse`] response body.
pub struct SseEventSink {
    /// Sending half of the channel feeding the SSE response stream.
    tx: mpsc::Sender<Result<Event, axum::Error>>,
}

impl SseEventSink {
    /// Wraps the sending half of an SSE channel.
    pub fn new(tx: mpsc::Sender<Result<Event, axum::Error>>) -> Self {
        SseEventSink { tx }
    }
}
#[async_trait]
impl EventSink for SseEventSink {
    /// Serializes `event.data` to JSON and sends it as one SSE event.
    ///
    /// The event's `id` is the loop sequence number and its `event` name is `event.kind`.
    ///
    /// # Errors
    /// Returns [`EventSinkError`] when the payload cannot be serialized, or when the channel
    /// receiver (the SSE response stream) has been dropped.
    async fn emit(&self, event: LoopEvent) -> Result<(), EventSinkError> {
        let payload = match serde_json::to_string(&event.data) {
            Ok(json) => json,
            Err(err) => return Err(EventSinkError(format!("serialize SSE data: {err}"))),
        };

        let frame = Event::default()
            .id(event.sequence.to_string())
            .event(event.kind.as_str())
            .data(payload);

        match self.tx.send(Ok(frame)).await {
            Ok(()) => Ok(()),
            Err(err) => Err(EventSinkError(format!("SSE channel closed: {err}"))),
        }
    }
}
/// Capacity of the channel used by `flex_stream_response` to replay rendered frames to the
/// SSE client once the flex request reaches a terminal state.
const FLEX_REPLAY_BUFFER: usize = 16;
/// One SSE frame to emit when replaying a finished flex request to a streaming client.
pub struct ReplayFrame {
    /// Optional SSE `event:` name; `None` emits a frame with only a `data:` field.
    pub event: Option<&'static str>,
    /// JSON payload, written to the SSE `data:` field in its compact string form.
    pub data: Value,
}

impl ReplayFrame {
    /// Builds a frame without an SSE event name.
    pub fn unnamed(data: Value) -> Self {
        ReplayFrame { data, event: None }
    }

    /// Builds a frame tagged with the SSE event name `event`.
    pub fn named(event: &'static str, data: Value) -> Self {
        let event = Some(event);
        ReplayFrame { event, data }
    }
}
pub async fn flex_stream_response<P, F>(
request_manager: Arc<fusillade::PostgresRequestManager<P, ReqwestHttpClient>>,
flex_input: fusillade::CreateFlexInput,
request_id: uuid::Uuid,
done_sentinel: bool,
render: F,
) -> axum::response::Response
where
P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
F: FnOnce(Result<&fusillade::RequestDetail, &str>) -> Vec<ReplayFrame> + Send + 'static,
{
use axum::response::IntoResponse;
if let Err(e) = fusillade::Storage::create_flex(&*request_manager, flex_input).await {
tracing::error!(error = %e, "Failed to create streaming flex batch in fusillade");
return axum::response::Response::builder()
.status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
.header("content-type", "application/json")
.body(axum::body::Body::from(
serde_json::json!({
"error": { "message": "Failed to enqueue request", "type": "server_error", "code": 500 }
})
.to_string(),
))
.unwrap();
}
let (tx, rx) = mpsc::channel::<Result<Event, std::convert::Infallible>>(FLEX_REPLAY_BUFFER);
tokio::spawn(async move {
let poll_interval = std::time::Duration::from_millis(500);
let timeout = std::time::Duration::from_secs(3600);
let result = crate::inference::store::poll_until_terminal(&request_manager, request_id, poll_interval, timeout).await;
let frames = match &result {
Ok(detail) => render(Ok(detail)),
Err(e) => {
tracing::error!(error = %e, request_id = %request_id, "Streaming flex poll failed");
render(Err(&e.to_string()))
}
};
for frame in frames {
let mut event = Event::default().data(frame.data.to_string());
if let Some(name) = frame.event {
event = event.event(name);
}
if tx.send(Ok(event)).await.is_err() {
return; }
}
if done_sentinel {
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
}
});
Sse::new(ReceiverStream::new(rx)).keep_alive(KeepAlive::default()).into_response()
}
/// Runs the multi-step response loop in a background task and streams its events as SSE.
///
/// Loop events are forwarded to the client through an [`SseEventSink`]. When the loop
/// finishes, the terminal state is persisted through the response store:
///
/// * success: the assembled response is stored as a `200` head via `persist_terminal_completed`.
///   If that fails, an extra `response.failed` SSE event with a `persist_failed` JSON payload is
///   sent to the client;
/// * `LoopError::Failed(payload)`: `payload` is stored as a `500` head;
/// * any other loop error: a `{"type": "loop_error", "message": ...}` payload is stored as a `500` head.
///
/// Persistence failures on the error paths are only logged. Afterwards the request is
/// unregistered from the store's pending set, and the SSE stream ends once every sender is
/// dropped.
// Each argument is a distinct dependency of the loop; bundling them would just move the noise.
#[allow(clippy::too_many_arguments)]
pub fn run_inline_streaming<P>(
    response_store: Arc<FusilladeResponseStore<P>>,
    tool_executor: Arc<HttpToolExecutor>,
    tool_resolved: Arc<crate::inference::tools::ResolvedToolSet>,
    http_client: Arc<ReqwestHttpClient>,
    upstream: UpstreamTarget,
    loop_config: LoopConfig,
    request_id: String,
    model_alias: String,
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>>
where
    P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
    let (tx, rx) = mpsc::channel::<Result<Event, axum::Error>>(SSE_CHANNEL_BUFFER);
    tokio::spawn(async move {
        let sink = SseEventSink::new(tx.clone());
        let tool_ctx = RequestContext::new()
            .with_model(model_alias)
            .with_extension(crate::inference::tools::ResolvedTools(tool_resolved));
        let result = onwards::run_response_loop(
            &*response_store,
            &*tool_executor,
            &tool_ctx,
            &upstream,
            http_client,
            Some(&sink),
            &request_id,
            None,
            loop_config,
            0,
        )
        .await;
        match &result {
            Ok(_) => {
                if let Err(e) = persist_terminal_completed(&response_store, &request_id).await {
                    tracing::warn!(error = %e, "Failed to persist warm-path terminal state");
                    // Build the payload with serde_json so quotes, backslashes and control
                    // characters in the error text are escaped. Interpolating them into a
                    // hand-written JSON string yields an invalid frame.
                    let payload = serde_json::json!({
                        "type": "persist_failed",
                        "message": e,
                    });
                    let _ = tx
                        .send(Ok(Event::default()
                            .event("response.failed")
                            .data(payload.to_string())))
                        .await;
                }
            }
            Err(LoopError::Failed(payload)) => {
                if let Err(e) = persist_terminal_failed(&response_store, &request_id, payload).await {
                    tracing::warn!(error = %e, "Failed to persist warm-path failure state");
                }
            }
            Err(other) => {
                let payload = serde_json::json!({
                    "type": "loop_error",
                    "message": other.to_string(),
                });
                if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
                    tracing::warn!(error = %e, "Failed to persist warm-path error state");
                }
            }
        }
        response_store.unregister_pending(&request_id);
        // The SSE stream ends once this sender and the sink's clone (dropped with the task) are gone.
        drop(tx);
    });
    Sse::new(ReceiverStream::new(rx)).keep_alive(KeepAlive::default())
}
/// Assembles the final response for `request_id` and stores it as a `200` head request.
///
/// # Errors
/// Returns a message prefixed with `assemble:` if assembling the response fails, or with
/// `finalize head:` if writing the head request fails.
async fn persist_terminal_completed<P>(response_store: &FusilladeResponseStore<P>, request_id: &str) -> Result<(), String>
where
    P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
    let body = match response_store.assemble_response(request_id).await {
        Ok(body) => body,
        Err(err) => return Err(format!("assemble: {err}")),
    };
    match response_store.finalize_head_request(request_id, 200, body).await {
        Ok(done) => Ok(done),
        Err(err) => Err(format!("finalize head: {err}")),
    }
}
#[allow(clippy::too_many_arguments)]
pub async fn run_inline_blocking<P>(
response_store: Arc<FusilladeResponseStore<P>>,
tool_executor: Arc<HttpToolExecutor>,
tool_resolved: Arc<crate::inference::tools::ResolvedToolSet>,
http_client: Arc<ReqwestHttpClient>,
upstream: UpstreamTarget,
loop_config: LoopConfig,
request_id: String,
model_alias: String,
) -> Result<Value, Value>
where
P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
let tool_ctx = RequestContext::new()
.with_model(model_alias)
.with_extension(crate::inference::tools::ResolvedTools(tool_resolved));
let result = onwards::run_response_loop(
&*response_store,
&*tool_executor,
&tool_ctx,
&upstream,
http_client,
None,
&request_id,
None,
loop_config,
0,
)
.await;
let outcome = match result {
Ok(_) => {
if let Err(e) = persist_terminal_completed(&response_store, &request_id).await {
tracing::warn!(error = %e, "Failed to persist warm-path-blocking terminal state");
}
response_store
.assemble_response(&request_id)
.await
.map_err(|e| serde_json::json!({"type": "assemble_failed", "message": e.to_string()}))
}
Err(LoopError::Failed(payload)) => {
if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
tracing::warn!(error = %e, "Failed to persist warm-path-blocking failure");
}
Err(payload)
}
Err(other) => {
let payload = serde_json::json!({
"type": "loop_error",
"message": other.to_string(),
});
if let Err(e) = persist_terminal_failed(&response_store, &request_id, &payload).await {
tracing::warn!(error = %e, "Failed to persist warm-path-blocking error");
}
Err(payload)
}
};
response_store.unregister_pending(&request_id);
outcome
}
async fn persist_terminal_failed<P>(response_store: &FusilladeResponseStore<P>, request_id: &str, error: &Value) -> Result<(), String>
where
P: fusillade::PoolProvider + Clone + Send + Sync + 'static,
{
response_store
.finalize_head_request(request_id, 500, error.clone())
.await
.map_err(|e| format!("finalize head: {e}"))
}