pub mod bridge;
pub mod processor;
pub mod routes;
use std::net::SocketAddr;
use std::sync::Arc;
use axum::{
Router,
routing::{get, post},
};
use syncable_ag_ui_core::{Event, JsonValue, RunId, ThreadId};
use tokio::sync::{RwLock, broadcast, mpsc};
use tower_http::cors::{Any, CorsLayer};
pub use bridge::EventBridge;
pub use processor::{AgentProcessor, ProcessorConfig, ThreadSession};
pub use syncable_ag_ui_core::types::{Context, Message as AgUiMessage, RunAgentInput, Tool};
/// One inbound request for the agent, wrapping an AG-UI [`RunAgentInput`].
///
/// Values of this type travel through the server's mpsc message queue
/// (see `ServerState::message_sender`).
#[derive(Clone, Debug)]
pub struct AgentMessage {
    /// The AG-UI run input carried by this message.
    pub input: RunAgentInput,
}

impl AgentMessage {
    /// Wraps `input` into an [`AgentMessage`].
    pub fn new(input: RunAgentInput) -> Self {
        AgentMessage { input }
    }
}
/// Configuration for [`AgUiServer`].
///
/// Build it with [`AgUiConfig::new`] plus the chained setters, or start from `Default`.
#[derive(Clone, Debug)]
pub struct AgUiConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host / interface to bind, e.g. `"127.0.0.1"`.
    pub host: String,
    /// Connection limit setting.
    /// NOTE(review): not read by `AgUiServer::run` in this file — confirm where it is enforced.
    pub max_connections: usize,
    /// Whether `AgUiServer::run` spawns an [`AgentProcessor`].
    pub enable_processor: bool,
    /// Settings for the processor; `None` falls back to `ProcessorConfig::default()`
    /// when the processor is started.
    pub processor_config: Option<ProcessorConfig>,
}
impl Default for AgUiConfig {
fn default() -> Self {
Self {
port: 9090,
host: "127.0.0.1".to_string(),
max_connections: 100,
enable_processor: false,
processor_config: None,
}
}
}
impl AgUiConfig {
pub fn new() -> Self {
Self::default()
}
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn host(mut self, host: impl Into<String>) -> Self {
self.host = host.into();
self
}
pub fn with_processor(mut self, enable: bool) -> Self {
self.enable_processor = enable;
if enable && self.processor_config.is_none() {
self.processor_config = Some(ProcessorConfig::default());
}
self
}
pub fn with_processor_config(mut self, config: ProcessorConfig) -> Self {
self.processor_config = Some(config);
self.enable_processor = true;
self
}
}
/// Shared state handed to the route handlers.
///
/// Cloning is cheap: every clone shares the same channels and the same
/// `Arc`-wrapped thread/run ids.
#[derive(Clone)]
pub struct ServerState {
    /// Fan-out sender for AG-UI events; each `subscribe()` call gets its own receiver.
    event_tx: broadcast::Sender<Event<JsonValue>>,
    /// Sender side of the inbound agent-message queue.
    message_tx: mpsc::Sender<AgentMessage>,
    /// Receiver side of the queue, kept in an `Option` so it can be taken exactly once.
    message_rx: Arc<RwLock<Option<mpsc::Receiver<AgentMessage>>>>,
    /// Current thread id, shared with every `EventBridge` created from this state.
    thread_id: Arc<RwLock<ThreadId>>,
    /// Current run id (if any), shared with every `EventBridge` created from this state.
    run_id: Arc<RwLock<Option<RunId>>>,
}
impl ServerState {
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(1000);
let (message_tx, message_rx) = mpsc::channel(100);
Self {
event_tx,
message_tx,
message_rx: Arc::new(RwLock::new(Some(message_rx))),
thread_id: Arc::new(RwLock::new(ThreadId::random())),
run_id: Arc::new(RwLock::new(None)),
}
}
pub fn event_sender(&self) -> EventBridge {
EventBridge::new(
self.event_tx.clone(),
Arc::clone(&self.thread_id),
Arc::clone(&self.run_id),
)
}
pub fn subscribe(&self) -> broadcast::Receiver<Event<JsonValue>> {
self.event_tx.subscribe()
}
pub fn message_sender(&self) -> mpsc::Sender<AgentMessage> {
self.message_tx.clone()
}
pub async fn take_message_receiver(&self) -> Option<mpsc::Receiver<AgentMessage>> {
self.message_rx.write().await.take()
}
}
impl Default for ServerState {
fn default() -> Self {
Self::new()
}
}
/// AG-UI HTTP server that owns its configuration and the shared [`ServerState`].
pub struct AgUiServer {
    /// Bind address and processor settings.
    config: AgUiConfig,
    /// Channels and ids shared with route handlers and event bridges.
    state: ServerState,
}
impl AgUiServer {
pub fn new(config: AgUiConfig) -> Self {
Self {
config,
state: ServerState::new(),
}
}
pub fn with_defaults() -> Self {
Self::new(AgUiConfig::default())
}
pub fn event_bridge(&self) -> EventBridge {
self.state.event_sender()
}
pub fn state(&self) -> ServerState {
self.state.clone()
}
pub async fn run(self) -> Result<(), std::io::Error> {
let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
.parse()
.expect("Invalid address");
if self.config.enable_processor {
let processor_config = self.config.processor_config.clone().unwrap_or_default();
if let Some(msg_rx) = self.state.take_message_receiver().await {
let event_bridge = self.state.event_sender();
let mut processor = AgentProcessor::new(msg_rx, event_bridge, processor_config);
tokio::spawn(async move {
processor.run().await;
});
println!("Agent processor started");
}
}
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/", get(routes::health).post(routes::post_message))
.route("/info", get(routes::info))
.route("/sse", get(routes::sse_handler))
.route("/ws", get(routes::ws_handler))
.route("/message", post(routes::post_message))
.route("/health", get(routes::health))
.layer(cors)
.with_state(self.state);
println!("AG-UI server listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await
}
pub fn addr(&self) -> String {
format!("{}:{}", self.config.host, self.config.port)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = AgUiConfig::default();
        assert_eq!(config.port, 9090);
        assert_eq!(config.host, "127.0.0.1");
    }

    #[test]
    fn test_config_builder() {
        let config = AgUiConfig::new().port(8080).host("0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.host, "0.0.0.0");
    }

    #[test]
    fn test_server_state_new() {
        let state = ServerState::new();
        let _bridge = state.event_sender();
        let _rx = state.subscribe();
    }

    #[test]
    fn test_server_addr() {
        let server = AgUiServer::with_defaults();
        assert_eq!(server.addr(), "127.0.0.1:9090");
    }

    #[test]
    fn test_event_bridge_from_state() {
        let state = ServerState::new();
        let bridge1 = state.event_sender();
        let bridge2 = state.event_sender();
        let _ = state.subscribe();
        drop(bridge1);
        drop(bridge2);
    }

    #[tokio::test]
    async fn test_server_event_flow() {
        use syncable_ag_ui_core::Event;
        let state = ServerState::new();
        let bridge = state.event_sender();
        let mut rx = state.subscribe();
        bridge.start_run().await;
        let event = rx.recv().await.expect("Should receive RunStarted");
        assert!(matches!(event, Event::RunStarted(_)));
    }

    #[tokio::test]
    async fn test_message_channel() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");
        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Hello agent")]);
        let agent_msg = AgentMessage::new(input);
        msg_tx.send(agent_msg).await.expect("Should send");
        let received = msg_rx.recv().await.expect("Should receive message");
        assert_eq!(received.input.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_message_receiver_only_once() {
        let state = ServerState::new();
        let rx1 = state.take_message_receiver().await;
        assert!(rx1.is_some());
        let rx2 = state.take_message_receiver().await;
        assert!(rx2.is_none());
    }

    #[test]
    fn test_config_with_processor() {
        let config = AgUiConfig::new().with_processor(true);
        assert!(config.enable_processor);
        assert!(config.processor_config.is_some());
    }

    #[test]
    fn test_config_with_processor_config() {
        let processor_config = ProcessorConfig::new()
            .with_provider("anthropic")
            .with_model("claude-3-sonnet");
        let config = AgUiConfig::new().with_processor_config(processor_config);
        assert!(config.enable_processor);
        let proc_config = config.processor_config.unwrap();
        assert_eq!(proc_config.provider, "anthropic");
        assert_eq!(proc_config.model, "claude-3-sonnet");
    }

    #[tokio::test]
    async fn test_processor_integration_with_state() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);
        let handle = tokio::spawn(async move {
            processor.run().await;
        });
        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input = RunAgentInput::new(thread_id.clone(), run_id.clone())
            .with_messages(vec![Message::new_user("Integration test message")]);
        msg_tx
            .send(AgentMessage::new(input))
            .await
            .expect("Should send");
        let event = tokio::time::timeout(std::time::Duration::from_millis(200), event_rx.recv())
            .await
            .expect("Should receive in time")
            .expect("Should have event");
        assert!(matches!(event, Event::RunStarted(_)));
        // Closing the sender lets the processor loop end; give it a moment to exit.
        drop(msg_tx);
        let _ = tokio::time::timeout(std::time::Duration::from_millis(200), handle).await;
    }

    /// Collects events from `rx` until one of these happens:
    /// - a terminal `RunFinished`/`RunError` event arrives (it is included);
    /// - the channel closes;
    /// - 5 seconds pass without a new event.
    ///
    /// A `Lagged` error only means some events were overwritten in the broadcast
    /// buffer. It is skipped rather than treated as end-of-stream.
    async fn collect_until_finished(
        rx: &mut tokio::sync::broadcast::Receiver<syncable_ag_ui_core::Event>,
    ) -> Vec<syncable_ag_ui_core::Event> {
        use syncable_ag_ui_core::Event;
        use tokio::sync::broadcast::error::RecvError;
        let mut events = Vec::new();
        loop {
            match tokio::time::timeout(std::time::Duration::from_secs(5), rx.recv()).await {
                Ok(Ok(event)) => {
                    let is_finished = matches!(&event, Event::RunFinished(_) | Event::RunError(_));
                    events.push(event);
                    if is_finished {
                        break;
                    }
                }
                Ok(Err(RecvError::Lagged(_))) => continue,
                Ok(Err(RecvError::Closed)) | Err(_) => break,
            }
        }
        events
    }

    /// Discards events until a `RunFinished` or `RunError` arrives.
    ///
    /// Panics if the channel closes first, or if 30 seconds pass without a new event.
    /// `Lagged` errors are skipped.
    async fn drain_events_until_run_finished(
        rx: &mut tokio::sync::broadcast::Receiver<syncable_ag_ui_core::Event>,
    ) {
        use syncable_ag_ui_core::Event;
        use tokio::sync::broadcast::error::RecvError;
        loop {
            match tokio::time::timeout(std::time::Duration::from_secs(30), rx.recv()).await {
                Ok(Ok(Event::RunFinished(_))) => break,
                Ok(Ok(Event::RunError(_))) => break,
                Ok(Ok(_)) | Ok(Err(RecvError::Lagged(_))) => continue,
                Ok(Err(RecvError::Closed)) => {
                    panic!("Event channel closed while waiting for RunFinished")
                }
                Err(_) => panic!("Timeout waiting for RunFinished"),
            }
        }
    }

    #[tokio::test]
    async fn test_multi_turn_conversation() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state
            .take_message_receiver()
            .await
            .expect("Should get receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);
        let handle = tokio::spawn(async move {
            processor.run().await;
        });
        // Both turns use the same thread id so the processor sees one conversation.
        let thread_id = ThreadId::random();
        let input1 = RunAgentInput::new(thread_id.clone(), RunId::random())
            .with_messages(vec![Message::new_user("Hello")]);
        msg_tx
            .send(AgentMessage::new(input1))
            .await
            .expect("Should send");
        drain_events_until_run_finished(&mut event_rx).await;
        let input2 = RunAgentInput::new(thread_id.clone(), RunId::random())
            .with_messages(vec![Message::new_user("Follow up message")]);
        msg_tx
            .send(AgentMessage::new(input2))
            .await
            .expect("Should send");
        drain_events_until_run_finished(&mut event_rx).await;
        drop(msg_tx);
        let _ = tokio::time::timeout(std::time::Duration::from_millis(200), handle).await;
    }

    #[tokio::test]
    async fn test_event_sequence() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);
        tokio::spawn(async move {
            processor.run().await;
        });
        let thread_id = ThreadId::random();
        let input = RunAgentInput::new(thread_id, RunId::random())
            .with_messages(vec![Message::new_user("Test event sequence")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();
        let events = collect_until_finished(&mut event_rx).await;
        assert!(!events.is_empty(), "Should receive at least one event");
        assert!(
            matches!(events[0], Event::RunStarted(_)),
            "First event should be RunStarted"
        );
        assert!(
            matches!(
                events.last(),
                Some(Event::RunFinished(_) | Event::RunError(_))
            ),
            "Last event should be RunFinished or RunError"
        );
        assert!(
            events.len() >= 2,
            "Should have at least RunStarted and terminal event"
        );
        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_empty_message_error() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::RunAgentInput;
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let mut processor = AgentProcessor::with_defaults(msg_rx, event_bridge);
        tokio::spawn(async move {
            processor.run().await;
        });
        // No messages attached: the processor is expected to report a RunError.
        let input = RunAgentInput::new(ThreadId::random(), RunId::random());
        msg_tx.send(AgentMessage::new(input)).await.unwrap();
        let events = collect_until_finished(&mut event_rx).await;
        // `first()` gives a readable assertion failure instead of an index panic
        // when no events arrived at all.
        assert!(
            matches!(events.first(), Some(Event::RunStarted(_))),
            "First should be RunStarted"
        );
        assert!(
            matches!(events.last(), Some(Event::RunError(_))),
            "Should end with RunError for empty message"
        );
        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_invalid_provider_error() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let config = ProcessorConfig::new().with_provider("invalid_provider_xyz");
        let mut processor = AgentProcessor::new(msg_rx, event_bridge, config);
        tokio::spawn(async move {
            processor.run().await;
        });
        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Test invalid provider")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();
        let events = collect_until_finished(&mut event_rx).await;
        assert!(
            matches!(events.last(), Some(Event::RunError(_))),
            "Should end with RunError for invalid provider"
        );
        drop(msg_tx);
    }

    #[tokio::test]
    async fn test_custom_system_prompt() {
        use syncable_ag_ui_core::types::{Message, RunAgentInput};
        let state = ServerState::new();
        let msg_tx = state.message_sender();
        let mut event_rx = state.subscribe();
        let msg_rx = state.take_message_receiver().await.expect("receiver");
        let event_bridge = state.event_sender();
        let config = ProcessorConfig::new().with_system_prompt(
            "You are a DevOps assistant. Always respond with deployment advice.",
        );
        let mut processor = AgentProcessor::new(msg_rx, event_bridge, config);
        tokio::spawn(async move {
            processor.run().await;
        });
        let input = RunAgentInput::new(ThreadId::random(), RunId::random())
            .with_messages(vec![Message::new_user("Hello")]);
        msg_tx.send(AgentMessage::new(input)).await.unwrap();
        drain_events_until_run_finished(&mut event_rx).await;
        drop(msg_tx);
    }
}