openai_agents_rust/
mcp_server.rs1use axum::{
2 Router,
3 extract::{Path, State},
4 response::{
5 IntoResponse, Json,
6 sse::{self, Sse},
7 },
8 routing::{get, post},
9};
10use serde::{Deserialize, Serialize};
11use std::{net::SocketAddr, sync::Arc};
12use tokio::net::TcpListener;
13use tokio::sync::broadcast::{Receiver, Sender};
14
15use crate::{
16 client::OpenAiClient,
17 config::Config,
18 error::AgentError,
19 model::{self, Model},
20 plugin::loader::PluginRegistry,
21};
22
23#[derive(Clone)]
25pub struct AppState {
26 pub config: Arc<Config>,
27 pub client: Arc<OpenAiClient>,
28 pub plugins: Arc<PluginRegistry>,
29 pub broadcaster: Sender<String>,
30}
31
32#[derive(Debug, Deserialize)]
34pub struct RunRequest {
35 pub model: String,
36 pub prompt: String,
37}
38
39#[derive(Debug, Serialize)]
41pub struct RunResponse {
42 pub result: String,
43}
44
45#[axum::debug_handler]
47async fn run_handler(
48 State(state): State<AppState>,
49 Json(payload): Json<RunRequest>,
50) -> Result<Json<RunResponse>, AgentError> {
51 let model: Box<dyn Model> = match payload.model.as_str() {
53 "openai_chat" => Box::new(model::openai_chat::OpenAiChat::new((*state.config).clone())),
54 "openai_realtime" => Box::new(model::openai_realtime::OpenAiRealtime::new(
55 (*state.config).clone(),
56 )),
57 "litellm" => Box::new(model::litellm::LiteLLM::new((*state.config).clone())),
58 _ => {
59 return Err(AgentError::Other(format!(
60 "Unknown model {}",
61 payload.model
62 )));
63 }
64 };
65
66 let result = model.generate(&payload.prompt).await?;
67 let _ = state.broadcaster.send(result.clone());
69
70 Ok(Json(RunResponse { result }))
71}
72
73async fn status_handler(
75 Path(session_id): Path<String>,
76 State(state): State<AppState>,
77) -> impl IntoResponse {
78 let subscriber_count = state.broadcaster.receiver_count();
79 Json(serde_json::json!({
80 "status": "running",
81 "session_id": session_id,
82 "subscribers": subscriber_count,
83 }))
84}
85
86async fn events_handler(
88 Path(_session_id): Path<String>,
89 State(state): State<AppState>,
90) -> Sse<impl futures_core::Stream<Item = Result<sse::Event, std::convert::Infallible>>> {
91 let mut rx: Receiver<String> = state.broadcaster.subscribe();
92 let stream = async_stream::stream! {
93 while let Ok(msg) = rx.recv().await {
94 yield Ok(sse::Event::default().data(msg));
95 }
96 };
97 Sse::new(stream)
98}
99
100pub fn router(state: AppState) -> Router {
102 Router::new()
103 .route("/run", post(run_handler))
104 .route("/status/:session_id", get(status_handler))
105 .route("/events/:session_id", get(events_handler))
106 .with_state(state)
107}
108
109pub async fn start_server(state: AppState, addr: SocketAddr) -> Result<(), AgentError> {
111 let app = router(state);
112 let listener = TcpListener::bind(addr)
113 .await
114 .map_err(|e| AgentError::Other(e.to_string()))?;
115 axum::serve(listener, app.into_make_service())
116 .await
117 .map_err(|e| AgentError::Other(e.to_string()))
118}