use std::convert::Infallible;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::Json;
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use futures::Stream;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use crate::council::builder::build_agents;
use crate::council::config::{CouncilConfig, EmbedderConfig};
use crate::council::embedder::Embedder;
use crate::council::event::CouncilEvent;
use crate::council::local_embedder::LocalGgufEmbedder;
use crate::council::{Council, CouncilParams};
/// JSON request body accepted by the council endpoints in this file.
#[derive(Debug, Deserialize)]
pub struct CouncilRequest {
/// Prompt text handed to the council for deliberation.
pub prompt: String,
/// Council configuration as a TOML document; parsed with
/// `CouncilConfig::from_toml_str` when the council is built.
pub config_toml: String,
}
/// JSON body returned by `council_completions` when deliberation succeeds.
#[derive(Debug, Serialize)]
pub struct CouncilCompletion {
/// Identifier of the form `council-<unix seconds>`.
pub id: String,
/// Object tag; this file always sets it to `"council.completion"`.
pub object: &'static str,
/// Creation time as whole seconds since the Unix epoch.
pub created: u64,
/// Answer produced by the council's deliberation.
pub answer: String,
}
/// JSON envelope for error responses: `{"error": {...}}`.
#[derive(Debug, Serialize)]
pub struct ErrorBody {
/// Details of the failure.
pub error: ErrorMessage,
}
/// Error details carried inside an `ErrorBody`.
#[derive(Debug, Serialize)]
pub struct ErrorMessage {
/// Human-readable description of what went wrong.
pub message: String,
/// Error category, serialized under the JSON key `type`. This file uses
/// `"invalid_request_error"` and `"upstream_error"`.
#[serde(rename = "type")]
pub kind: &'static str,
}
fn bad_request(msg: String) -> Response {
(
StatusCode::BAD_REQUEST,
Json(ErrorBody {
error: ErrorMessage {
message: msg,
kind: "invalid_request_error",
},
}),
)
.into_response()
}
fn upstream_error(msg: String) -> Response {
(
StatusCode::BAD_GATEWAY,
Json(ErrorBody {
error: ErrorMessage {
message: msg,
kind: "upstream_error",
},
}),
)
.into_response()
}
/// Runs a single council deliberation and returns the answer as JSON.
///
/// The council is built from the request's TOML config via `build_council`;
/// if that fails, the error response it produced (400 or 502) is returned
/// unchanged. A deliberation failure is reported as a `502 Bad Gateway`
/// through `upstream_error`. On success the body is a [`CouncilCompletion`].
pub async fn council_completions(Json(req): Json<CouncilRequest>) -> Response {
    // `_embedder` stays bound (rather than `_`) so the handle is held until
    // this function returns, exactly as before.
    let (council, _embedder) = match build_council(&req).await {
        Ok(built) => built,
        Err(resp) => return resp,
    };
    let prompt = req.prompt;
    match council.deliberate(&prompt).await {
        Ok(answer) => {
            // Read the clock once: two separate reads could straddle a second
            // boundary and make `id` disagree with `created`.
            let created = now_secs();
            Json(CouncilCompletion {
                id: format!("council-{created}"),
                object: "council.completion",
                created,
                answer,
            })
            .into_response()
        }
        Err(e) => upstream_error(e.to_string()),
    }
}
/// Streams a council deliberation to the client as Server-Sent Events.
///
/// Each [`CouncilEvent`] yielded by `Council::deliberate_stream` is serialized
/// to JSON and sent as the `data` of one SSE event. The response uses the
/// default keep-alive settings.
///
/// # Errors
///
/// Returns the error [`Response`] produced by `build_council` when the council
/// cannot be constructed from the request.
pub async fn council_transcript(
    Json(req): Json<CouncilRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, Response> {
    let (council, _embedder) = build_council(&req).await?;
    let prompt = req.prompt;
    let inner: std::pin::Pin<Box<dyn Stream<Item = CouncilEvent> + Send>> =
        Box::pin(council.deliberate_stream(&prompt));
    // The council travels in the unfold state, so it is dropped only when the
    // SSE stream itself is dropped.
    let stream = futures::stream::unfold(
        (council, inner),
        |(council, mut inner)| async move {
            inner.next().await.map(|ev| {
                let json = serde_json::to_string(&ev).unwrap_or_else(|e| {
                    // Build the fallback as a JSON value so the error text is
                    // escaped; splicing it into a hand-written literal would
                    // emit invalid JSON if it contained quotes or backslashes.
                    serde_json::json!({
                        "type": "council_failed",
                        "error": {
                            "code": "internal",
                            "message": e.to_string(),
                        },
                    })
                    .to_string()
                });
                (
                    Ok::<_, Infallible>(Event::default().data(json)),
                    (council, inner),
                )
            })
        },
    );
    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
/// Parses the request's TOML config and assembles a [`Council`] together with
/// the embedder it was given.
///
/// # Errors
///
/// The error is a ready-to-send [`Response`]:
/// - `400 Bad Request` (via `bad_request`) when `config_toml` fails to parse;
/// - `502 Bad Gateway` (via `upstream_error`) when building the agents or
///   loading the embedder fails.
async fn build_council(
    req: &CouncilRequest,
) -> Result<(Council, Arc<dyn Embedder>), Response> {
    let cfg = match CouncilConfig::from_toml_str(&req.config_toml) {
        Ok(parsed) => parsed,
        Err(err) => return Err(bad_request(format!("bad council config: {err}"))),
    };

    let (experts, synthesizer) = match build_agents(&cfg).await {
        Ok(agents) => agents,
        Err(err) => return Err(upstream_error(err.to_string())),
    };

    // Exhaustive on purpose: a new embedder kind must be handled here.
    let embedder: Arc<dyn Embedder> = match &cfg.embedder {
        EmbedderConfig::LocalGguf { path } => match LocalGgufEmbedder::load(path) {
            Ok(loaded) => Arc::new(loaded),
            Err(err) => return Err(upstream_error(format!("embedder load: {err}"))),
        },
    };

    // The council gets its own handle; the original is returned to the caller.
    let shared = Arc::clone(&embedder);
    let params = CouncilParams::from_config(&cfg);
    let council = Council::new(experts, synthesizer, shared, params);
    Ok((council, embedder))
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// or `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    /// POSTs `{"prompt": "hi", "config_toml": <config_toml>}` to a fresh router
    /// exposing `council_completions` and returns the response status.
    async fn post_completion(config_toml: &str) -> StatusCode {
        let router = axum::Router::new()
            .route("/v1/council/completions", axum::routing::post(council_completions));
        let payload = serde_json::json!({
            "prompt": "hi",
            "config_toml": config_toml,
        });
        let request = Request::builder()
            .method("POST")
            .uri("/v1/council/completions")
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_string(&payload).unwrap()))
            .unwrap();
        router.oneshot(request).await.unwrap().status()
    }

    /// A config string that is not valid TOML is rejected with 400.
    #[tokio::test]
    async fn completions_rejects_bad_config() {
        let status = post_completion("not-valid-toml {}{}").await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    /// A config that declares only expert agents (no synthesizer) is rejected
    /// with 400.
    #[tokio::test]
    async fn completions_rejects_missing_synthesizer() {
        // Raw-string contents are kept at column 0 so the TOML text is unchanged.
        let cfg = r#"
[embedder]
kind = "local_gguf"
path = "/nonexistent.gguf"
[[agent]]
role = "expert"
endpoint = "grpc://a:1"
model = "m1"
timeout_ms = 1000
[[agent]]
role = "expert"
endpoint = "grpc://b:1"
model = "m2"
timeout_ms = 1000
"#;
        let status = post_completion(cfg).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }
}