// openai_agents_rust/mcp_server.rs

use axum::{
    Router,
    extract::{Path, State},
    response::{
        IntoResponse, Json,
        sse::{self, Sse},
    },
    routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpListener;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::{Receiver, Sender};

use crate::{
    client::OpenAiClient,
    config::Config,
    error::AgentError,
    model::{self, Model},
    plugin::loader::PluginRegistry,
};

23/// Shared application state for the MCP server.
24#[derive(Clone)]
25pub struct AppState {
26    pub config: Arc<Config>,
27    pub client: Arc<OpenAiClient>,
28    pub plugins: Arc<PluginRegistry>,
29    pub broadcaster: Sender<String>,
30}
31
32/// Request payload for the `/run` endpoint.
33#[derive(Debug, Deserialize)]
34pub struct RunRequest {
35    pub model: String,
36    pub prompt: String,
37}
38
39/// Simple response wrapper.
40#[derive(Debug, Serialize)]
41pub struct RunResponse {
42    pub result: String,
43}
44
45/// Handler for `/run` – forwards the prompt to the selected model.
46#[axum::debug_handler]
47async fn run_handler(
48    State(state): State<AppState>,
49    Json(payload): Json<RunRequest>,
50) -> Result<Json<RunResponse>, AgentError> {
51    // Instantiate the requested model.
52    let model: Box<dyn Model> = match payload.model.as_str() {
53        "openai_chat" => Box::new(model::openai_chat::OpenAiChat::new((*state.config).clone())),
54        "openai_realtime" => Box::new(model::openai_realtime::OpenAiRealtime::new(
55            (*state.config).clone(),
56        )),
57        "litellm" => Box::new(model::litellm::LiteLLM::new((*state.config).clone())),
58        _ => {
59            return Err(AgentError::Other(format!(
60                "Unknown model {}",
61                payload.model
62            )));
63        }
64    };
65
66    let result = model.generate(&payload.prompt).await?;
67    // Broadcast the result to any SSE listeners.
68    let _ = state.broadcaster.send(result.clone());
69
70    Ok(Json(RunResponse { result }))
71}
72
73/// Handler for `/status/:session_id` – returns basic liveness and session echo.
74async fn status_handler(
75    Path(session_id): Path<String>,
76    State(state): State<AppState>,
77) -> impl IntoResponse {
78    let subscriber_count = state.broadcaster.receiver_count();
79    Json(serde_json::json!({
80        "status": "running",
81        "session_id": session_id,
82        "subscribers": subscriber_count,
83    }))
84}
85
86/// Handler for `/events/:session_id` – Server‑Sent Events stream.
87async fn events_handler(
88    Path(_session_id): Path<String>,
89    State(state): State<AppState>,
90) -> Sse<impl futures_core::Stream<Item = Result<sse::Event, std::convert::Infallible>>> {
91    let mut rx: Receiver<String> = state.broadcaster.subscribe();
92    let stream = async_stream::stream! {
93        while let Ok(msg) = rx.recv().await {
94            yield Ok(sse::Event::default().data(msg));
95        }
96    };
97    Sse::new(stream)
98}
99
100/// Build the Axum router.
101pub fn router(state: AppState) -> Router {
102    Router::new()
103        .route("/run", post(run_handler))
104        .route("/status/:session_id", get(status_handler))
105        .route("/events/:session_id", get(events_handler))
106        .with_state(state)
107}
108
109/// Start the MCP server – called from `main.rs`.
110pub async fn start_server(state: AppState, addr: SocketAddr) -> Result<(), AgentError> {
111    let app = router(state);
112    let listener = TcpListener::bind(addr)
113        .await
114        .map_err(|e| AgentError::Other(e.to_string()))?;
115    axum::serve(listener, app.into_make_service())
116        .await
117        .map_err(|e| AgentError::Other(e.to_string()))
118}