use std::sync::Arc;
use axum::{
Json, Router,
extract::State,
http::{HeaderMap, StatusCode},
response::sse::{Event, Sse},
routing::{get, post},
};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use crate::agent::Agent;
use crate::error::Result;
/// Shared state handed to every route handler (wrapped in an `Arc` by `AgentServer::serve`).
struct AppState {
    /// The agent that serves `/prompt` and `/prompt/stream` requests.
    agent: Agent,
    /// Expected API key. `None` disables authentication entirely (see `check_api_key`).
    api_key: Option<String>,
}
/// JSON body accepted by `POST /prompt` and `POST /prompt/stream`.
#[derive(Debug, Clone, Deserialize)]
pub struct PromptRequest {
    /// Prompt text passed verbatim to the agent.
    pub input: String,
}
/// JSON body returned by `POST /prompt` once the agent run has completed.
///
/// Field order is significant: it determines the key order of the serialized JSON.
#[derive(Debug, Clone, Serialize)]
pub struct PromptResponse {
    /// Final text produced by the agent.
    pub text: String,
    /// Iteration count reported by the agent (presumably model/tool loop turns — confirm).
    pub iterations: usize,
    /// Input tokens consumed, as reported in the agent's usage data.
    pub input_tokens: u32,
    /// Output tokens produced, as reported in the agent's usage data.
    pub output_tokens: u32,
    /// Cost reported by the agent. NOTE(review): units/currency are not defined here — confirm.
    pub cost: f64,
}
/// HTTP front-end exposing an [`Agent`] over `/health`, `/prompt` and `/prompt/stream`.
///
/// Configure it with the builder methods (`bind`, `api_key`), then call `serve`.
pub struct AgentServer {
    /// The agent that handles prompt requests.
    agent: Agent,
    /// Socket address to listen on; defaults to `0.0.0.0:8080`.
    bind_addr: String,
    /// Optional API key; when `None`, requests are not authenticated.
    api_key: Option<String>,
}
impl AgentServer {
    /// Creates a server for `agent` listening on `0.0.0.0:8080` with authentication disabled.
    pub fn new(agent: Agent) -> Self {
        Self {
            agent,
            bind_addr: "0.0.0.0:8080".to_string(),
            api_key: None,
        }
    }

    /// Sets the socket address to listen on (e.g. `"127.0.0.1:3000"`).
    #[must_use]
    pub fn bind(mut self, addr: impl Into<String>) -> Self {
        self.bind_addr = addr.into();
        self
    }

    /// Requires callers to present `key` via `Authorization: Bearer <key>` or `X-API-Key`.
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Binds the listener and serves requests until the server stops.
    ///
    /// Routes: `GET /health`, `POST /prompt`, `POST /prompt/stream`.
    ///
    /// # Errors
    ///
    /// Returns `DaimonError::Other` if the address cannot be bound or if the server fails
    /// while running.
    pub async fn serve(self) -> Result<()> {
        let Self {
            agent,
            bind_addr,
            api_key,
        } = self;

        if api_key.is_none() {
            // The default bind address is all interfaces, so make an open server visible.
            tracing::warn!("no API key configured; prompt endpoints accept unauthenticated requests");
        }

        let state = Arc::new(AppState { agent, api_key });
        let app = Router::new()
            .route("/health", get(health))
            .route("/prompt", post(prompt_handler))
            .route("/prompt/stream", post(prompt_stream_handler))
            .with_state(state);

        let listener = tokio::net::TcpListener::bind(&bind_addr)
            .await
            .map_err(|e| crate::error::DaimonError::Other(format!("bind error: {e}")))?;

        // Log the address actually bound (resolves port 0 / hostnames); fall back to the
        // configured string if the OS cannot report it.
        match listener.local_addr() {
            Ok(addr) => tracing::info!(%addr, "agent server listening"),
            Err(_) => tracing::info!(addr = %bind_addr, "agent server listening"),
        }

        axum::serve(listener, app)
            .await
            .map_err(|e| crate::error::DaimonError::Other(format!("server error: {e}")))?;
        Ok(())
    }
}
/// Enforces the optional API key on a request.
///
/// When no key is configured (`state.api_key` is `None`), every request is accepted.
/// Otherwise the key is read from `Authorization: Bearer <key>`. The scheme is matched
/// case-insensitively, as RFC 7235 specifies for auth schemes. If that header is absent,
/// not valid visible ASCII, or not a Bearer credential, the key is read from `X-API-Key`.
///
/// # Errors
///
/// Returns `401 Unauthorized` with a fixed message when the key is missing or does not match.
fn check_api_key(
    state: &AppState,
    headers: &HeaderMap,
) -> std::result::Result<(), (StatusCode, String)> {
    let Some(expected) = &state.api_key else {
        return Ok(());
    };
    let provided = headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token)
        .or_else(|| headers.get("x-api-key").and_then(|v| v.to_str().ok()));
    match provided {
        // Avoid `==` on the secret: an early-exit compare lets response timing reveal how
        // many leading bytes of a guessed key were correct.
        Some(key) if constant_time_eq(key.as_bytes(), expected.as_bytes()) => Ok(()),
        _ => Err((StatusCode::UNAUTHORIZED, "invalid or missing API key".to_string())),
    }
}

/// Extracts the token from a `Bearer <token>` credential.
///
/// The scheme is matched case-insensitively. Returns `None` for any other scheme or for a
/// value that contains no space.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then(|| token.trim_start_matches(' '))
}

/// Compares two byte strings without stopping at the first mismatch.
///
/// This is best-effort: Rust gives no hard constant-time guarantee. The length is not
/// hidden, and a length mismatch returns `false` immediately.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // `black_box` on each step discourages the optimizer from turning the fold back into
    // an early-exit comparison.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| acc | std::hint::black_box(x ^ y));
    diff == 0
}
/// `GET /health` — liveness probe that always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const STATUS_BODY: &str = "ok";
    STATUS_BODY
}
/// `POST /prompt` — runs the agent to completion on the request's `input`.
///
/// On success it returns the final text with the iteration count, token usage and cost.
///
/// # Errors
///
/// - `401 Unauthorized` when the API key check fails.
/// - `500 Internal Server Error` when the agent run fails. The body is the agent error's
///   `Display` text. NOTE(review): this exposes internal error details to clients.
async fn prompt_handler(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Json(req): Json<PromptRequest>,
) -> std::result::Result<Json<PromptResponse>, (StatusCode, String)> {
    check_api_key(&state, &headers)?;

    let outcome = match state.agent.prompt(&req.input).await {
        Ok(outcome) => outcome,
        Err(err) => return Err((StatusCode::INTERNAL_SERVER_ERROR, err.to_string())),
    };

    let body = PromptResponse {
        text: outcome.final_text,
        iterations: outcome.iterations,
        input_tokens: outcome.usage.input_tokens,
        output_tokens: outcome.usage.output_tokens,
        cost: outcome.cost,
    };
    Ok(Json(body))
}
async fn prompt_stream_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<PromptRequest>,
) -> std::result::Result<
Sse<impl futures::Stream<Item = std::result::Result<Event, axum::Error>>>,
(StatusCode, String),
> {
check_api_key(&state, &headers)?;
let stream = state
.agent
.prompt_stream(&req.input)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let sse_stream = stream.map(|event| {
let event = event.map_err(|e| axum::Error::new(e))?;
let data = serde_json::to_string(&format!("{event:?}")).unwrap_or_default();
Ok(Event::default().data(data))
});
Ok(Sse::new(sse_stream))
}