1pub mod bridge;
42pub mod processor;
43pub mod routes;
44
45use std::net::SocketAddr;
46use std::sync::Arc;
47
48use axum::{
49 Router,
50 routing::{get, post},
51};
52use syncable_ag_ui_core::{Event, JsonValue, RunId, ThreadId};
53use tokio::sync::{RwLock, broadcast, mpsc};
54use tower_http::cors::{Any, CorsLayer};
55
56pub use bridge::EventBridge;
57pub use processor::{AgentProcessor, ProcessorConfig, ThreadSession};
58
59pub use syncable_ag_ui_core::types::{Context, Message as AgUiMessage, RunAgentInput, Tool};
61
62#[derive(Debug, Clone)]
65pub struct AgentMessage {
66 pub input: RunAgentInput,
68}
69
70impl AgentMessage {
71 pub fn new(input: RunAgentInput) -> Self {
73 Self { input }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct AgUiConfig {
80 pub port: u16,
82 pub host: String,
84 pub max_connections: usize,
86 pub enable_processor: bool,
88 pub processor_config: Option<ProcessorConfig>,
90}
91
92impl Default for AgUiConfig {
93 fn default() -> Self {
94 Self {
95 port: 9090,
96 host: "127.0.0.1".to_string(),
97 max_connections: 100,
98 enable_processor: false,
99 processor_config: None,
100 }
101 }
102}
103
104impl AgUiConfig {
105 pub fn new() -> Self {
107 Self::default()
108 }
109
110 pub fn port(mut self, port: u16) -> Self {
112 self.port = port;
113 self
114 }
115
116 pub fn host(mut self, host: impl Into<String>) -> Self {
118 self.host = host.into();
119 self
120 }
121
122 pub fn with_processor(mut self, enable: bool) -> Self {
127 self.enable_processor = enable;
128 if enable && self.processor_config.is_none() {
129 self.processor_config = Some(ProcessorConfig::default());
130 }
131 self
132 }
133
134 pub fn with_processor_config(mut self, config: ProcessorConfig) -> Self {
136 self.processor_config = Some(config);
137 self.enable_processor = true;
138 self
139 }
140}
141
142#[derive(Clone)]
144pub struct ServerState {
145 event_tx: broadcast::Sender<Event<JsonValue>>,
147 message_tx: mpsc::Sender<AgentMessage>,
149 message_rx: Arc<RwLock<Option<mpsc::Receiver<AgentMessage>>>>,
151 thread_id: Arc<RwLock<ThreadId>>,
153 run_id: Arc<RwLock<Option<RunId>>>,
155}
156
157impl ServerState {
158 pub fn new() -> Self {
160 let (event_tx, _) = broadcast::channel(1000);
161 let (message_tx, message_rx) = mpsc::channel(100);
162 Self {
163 event_tx,
164 message_tx,
165 message_rx: Arc::new(RwLock::new(Some(message_rx))),
166 thread_id: Arc::new(RwLock::new(ThreadId::random())),
167 run_id: Arc::new(RwLock::new(None)),
168 }
169 }
170
171 pub fn event_sender(&self) -> EventBridge {
173 EventBridge::new(
174 self.event_tx.clone(),
175 Arc::clone(&self.thread_id),
176 Arc::clone(&self.run_id),
177 )
178 }
179
180 pub fn subscribe(&self) -> broadcast::Receiver<Event<JsonValue>> {
182 self.event_tx.subscribe()
183 }
184
185 pub fn message_sender(&self) -> mpsc::Sender<AgentMessage> {
187 self.message_tx.clone()
188 }
189
190 pub async fn take_message_receiver(&self) -> Option<mpsc::Receiver<AgentMessage>> {
195 self.message_rx.write().await.take()
196 }
197}
198
199impl Default for ServerState {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205pub struct AgUiServer {
207 config: AgUiConfig,
208 state: ServerState,
209}
210
211impl AgUiServer {
212 pub fn new(config: AgUiConfig) -> Self {
214 Self {
215 config,
216 state: ServerState::new(),
217 }
218 }
219
220 pub fn with_defaults() -> Self {
222 Self::new(AgUiConfig::default())
223 }
224
225 pub fn event_bridge(&self) -> EventBridge {
227 self.state.event_sender()
228 }
229
230 pub fn state(&self) -> ServerState {
232 self.state.clone()
233 }
234
235 pub async fn run(self) -> Result<(), std::io::Error> {
240 let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
241 .parse()
242 .expect("Invalid address");
243
244 if self.config.enable_processor {
246 let processor_config = self.config.processor_config.clone().unwrap_or_default();
247
248 if let Some(msg_rx) = self.state.take_message_receiver().await {
249 let event_bridge = self.state.event_sender();
250 let mut processor = AgentProcessor::new(msg_rx, event_bridge, processor_config);
251
252 tokio::spawn(async move {
254 processor.run().await;
255 });
256
257 println!("Agent processor started");
258 }
259 }
260
261 let cors = CorsLayer::new()
263 .allow_origin(Any)
264 .allow_methods(Any)
265 .allow_headers(Any);
266
267 let app = Router::new()
268 .route("/", get(routes::health).post(routes::post_message))
269 .route("/info", get(routes::info))
270 .route("/sse", get(routes::sse_handler))
271 .route("/ws", get(routes::ws_handler))
272 .route("/message", post(routes::post_message))
273 .route("/health", get(routes::health))
274 .layer(cors)
275 .with_state(self.state);
276
277 println!("AG-UI server listening on http://{}", addr);
278
279 let listener = tokio::net::TcpListener::bind(addr).await?;
280 axum::serve(listener, app).await
281 }
282
283 pub fn addr(&self) -> String {
285 format!("{}:{}", self.config.host, self.config.port)
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = AgUiConfig::default();
        assert_eq!(config.port, 9090);
        assert_eq!(config.host, "127.0.0.1");
    }

    #[test]
    fn test_config_builder() {
        let config = AgUiConfig::new().port(8080).host("0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.host, "0.0.0.0");
    }

    #[test]
    fn test_server_state_new() {
        let state = ServerState::new();
        let _bridge = state.event_sender();
        let _rx = state.subscribe();
    }

    #[test]
    fn test_server_addr() {
        let server = AgUiServer::with_defaults();
        assert_eq!(server.addr(), "127.0.0.1:9090");
    }

    #[test]
    fn test_event_bridge_from_state() {
        let state = ServerState::new();
        let bridge1 = state.event_sender();
        let bridge2 = state.event_sender();

        let _ = state.subscribe();

        drop(bridge1);
        drop(bridge2);
    }

    #[tokio::test]
    async fn test_server_event_flow() {
        use syncable_ag_ui_core::Event;

        let state = ServerState::new();
        let bridge = state.event_sender();
        let mut rx = state.subscribe();

        bridge.start_run().await;

        let event = rx.recv().await.expect("Should receive RunStarted");
        assert!(matches!(event, Event::RunStarted(_)));
    }

    #[tokio::test]
    async fn test_message_channel() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Hello agent")]);

        let agent_msg = AgentMessage::new(input);
        msg_tx.send(agent_msg).await.expect("Should send");

        let received = msg_rx.recv().await.expect("Should receive message");
        assert_eq!(received.input.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_message_receiver_only_once() {
        let state = ServerState::new();

        let rx1 = state.take_message_receiver().await;
        assert!(rx1.is_some());

        let rx2 = state.take_message_receiver().await;
        assert!(rx2.is_none());
    }

    #[test]
    fn test_config_with_processor() {
        let config = AgUiConfig::new().with_processor(true);
        assert!(config.enable_processor);
        assert!(config.processor_config.is_some());
    }

    #[test]
    fn test_config_with_processor_config() {
        let processor_config = ProcessorConfig::new()
            .with_provider("anthropic")
            .with_model("claude-3-sonnet");

        let config = AgUiConfig::new().with_processor_config(processor_config);

        assert!(config.enable_processor);
        let proc_config = config.processor_config.unwrap();
        assert_eq!(proc_config.provider, "anthropic");
        assert_eq!(proc_config.model, "claude-3-sonnet");
    }

    #[tokio::test]
    async fn test_processor_integration_with_state() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);

        let handle = tokio::spawn(async move {
            processor.run().await;
        });

        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input = RunAgentInput::new(thread_id.clone(), run_id.clone())
            .with_messages(vec![Message::new_user("Integration test message")]);

        msg_tx
            .send(AgentMessage::new(input))
            .await
            .expect("Should send");

        let event = tokio::time::timeout(std::time::Duration::from_millis(200), event_rx.recv())
            .await
            .expect("Should receive in time")
            .expect("Should have event");

        assert!(matches!(event, Event::RunStarted(_)));

        // Closing the only sender lets the processor's receive loop end.
        drop(msg_tx);

        let _ = tokio::time::timeout(std::time::Duration::from_millis(200), handle).await;
    }

    /// Collects events until a terminal `RunFinished`/`RunError` arrives.
    ///
    /// Returns whatever was gathered so far if no event arrives within 5
    /// seconds or the channel closes. A `Lagged` error only means some
    /// events were overwritten in the broadcast buffer, so reading continues
    /// rather than stopping early.
    async fn collect_until_finished(
        rx: &mut tokio::sync::broadcast::Receiver<syncable_ag_ui_core::Event>,
    ) -> Vec<syncable_ag_ui_core::Event> {
        use syncable_ag_ui_core::Event;
        use tokio::sync::broadcast::error::RecvError;

        let mut events = Vec::new();
        loop {
            match tokio::time::timeout(std::time::Duration::from_secs(5), rx.recv()).await {
                Ok(Ok(event)) => {
                    let is_finished = matches!(&event, Event::RunFinished(_) | Event::RunError(_));
                    events.push(event);
                    if is_finished {
                        break;
                    }
                }
                Ok(Err(RecvError::Lagged(_))) => continue,
                Ok(Err(RecvError::Closed)) | Err(_) => break,
            }
        }
        events
    }

    /// Discards events until `RunFinished` or `RunError` arrives.
    ///
    /// Panics if 30 seconds pass without an event or the channel closes.
    /// A `Lagged` error is tolerated, as in `collect_until_finished`.
    async fn drain_events_until_run_finished(
        rx: &mut tokio::sync::broadcast::Receiver<syncable_ag_ui_core::Event>,
    ) {
        use syncable_ag_ui_core::Event;
        use tokio::sync::broadcast::error::RecvError;

        loop {
            match tokio::time::timeout(std::time::Duration::from_secs(30), rx.recv()).await {
                Ok(Ok(Event::RunFinished(_))) | Ok(Ok(Event::RunError(_))) => break,
                Ok(Ok(_)) | Ok(Err(RecvError::Lagged(_))) => continue,
                Ok(Err(RecvError::Closed)) => {
                    panic!("Event channel closed while waiting for RunFinished")
                }
                Err(_) => panic!("Timeout waiting for RunFinished"),
            }
        }
    }

    #[tokio::test]
    async fn test_multi_turn_conversation() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");

        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);

        let handle = tokio::spawn(async move {
            processor.run().await;
        });

        let thread_id = ThreadId::random();

        let input1 = RunAgentInput::new(thread_id.clone(), RunId::random())
            .with_messages(vec![Message::new_user("Hello")]);
        msg_tx
            .send(AgentMessage::new(input1))
            .await
            .expect("Should send");

        drain_events_until_run_finished(&mut event_rx).await;

        // Second turn on the same thread id.
        let input2 = RunAgentInput::new(thread_id.clone(), RunId::random())
            .with_messages(vec![Message::new_user("Follow up message")]);
        msg_tx
            .send(AgentMessage::new(input2))
            .await
            .expect("Should send");

        drain_events_until_run_finished(&mut event_rx).await;

        drop(msg_tx);
        let _ = tokio::time::timeout(std::time::Duration::from_millis(200), handle).await;
    }

    #[tokio::test]
    async fn test_event_sequence() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);

        tokio::spawn(async move {
            processor.run().await;
        });

        let thread_id = ThreadId::random();
        let input = RunAgentInput::new(thread_id, RunId::random())
            .with_messages(vec![Message::new_user("Test event sequence")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();

        let events = collect_until_finished(&mut event_rx).await;

        assert!(!events.is_empty(), "Should receive at least one event");
        assert!(
            matches!(events.first(), Some(Event::RunStarted(_))),
            "First event should be RunStarted"
        );

        assert!(
            matches!(
                events.last(),
                Some(Event::RunFinished(_) | Event::RunError(_))
            ),
            "Last event should be RunFinished or RunError"
        );

        assert!(
            events.len() >= 2,
            "Should have at least RunStarted and terminal event"
        );

        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_empty_message_error() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::RunAgentInput;

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);

        tokio::spawn(async move {
            processor.run().await;
        });

        // No messages attached: the processor is expected to report an error.
        let input = RunAgentInput::new(ThreadId::random(), RunId::random());
        msg_tx.send(AgentMessage::new(input)).await.unwrap();

        let events = collect_until_finished(&mut event_rx).await;

        // `first()` instead of `[0]` so an empty result fails with this
        // message rather than an index-out-of-bounds panic.
        assert!(
            matches!(events.first(), Some(Event::RunStarted(_))),
            "First should be RunStarted"
        );
        assert!(
            matches!(events.last(), Some(Event::RunError(_))),
            "Should end with RunError for empty message"
        );

        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_invalid_provider_error() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();

        let config = ProcessorConfig::new().with_provider("invalid_provider_xyz");
        let mut processor = AgentProcessor::new(msg_rx, event_bridge, config);

        tokio::spawn(async move {
            processor.run().await;
        });

        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Test invalid provider")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();

        let events = collect_until_finished(&mut event_rx).await;

        assert!(
            matches!(events.last(), Some(Event::RunError(_))),
            "Should end with RunError for invalid provider"
        );

        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_custom_system_prompt() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};

        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();

        let config = ProcessorConfig::new().with_system_prompt(
            "You are a DevOps assistant. Always respond with deployment advice.",
        );
        let mut processor = AgentProcessor::new(msg_rx, event_bridge, config);

        tokio::spawn(async move {
            processor.run().await;
        });

        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Hello")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();

        drain_events_until_run_finished(&mut event_rx).await;

        drop(msg_tx);
    }
}