sayr_engine/
server.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use axum::extract::{Path, Query, State};
7use axum::http::{HeaderMap, StatusCode};
8use axum::response::sse::{Event, Sse};
9use axum::response::{Html, IntoResponse, Response};
10use axum::routing::{get, post};
11use axum::Json;
12use axum::Router;
13use futures::stream::Stream;
14use futures::StreamExt;
15use serde::Deserialize;
16use serde::Serialize;
17use serde_json::{json, Value};
18use tokio::sync::{broadcast, Mutex, RwLock};
19use tokio_stream::wrappers::BroadcastStream;
20
21use crate::message::Message;
22use crate::{
23    AccessController, Action, GovernanceRole, LanguageModel, Principal, Result, SecurityConfig,
24    Team, TelemetryCollector, Workflow,
25};
26
27pub struct AgentRuntime<M: LanguageModel + 'static> {
28    pub teams: Arc<RwLock<HashMap<String, Team<M>>>>,
29    pub workflows: Arc<RwLock<HashMap<String, Arc<Workflow>>>>,
30    pub agents: Arc<RwLock<HashMap<String, Arc<Mutex<crate::Agent<M>>>>>>,
31    pub events: broadcast::Sender<String>,
32    trace_events: broadcast::Sender<TraceEvent>,
33    security: SecurityConfig,
34    access_control: AccessController,
35    telemetry: TelemetryCollector,
36    metrics: crate::MetricsTracker,
37}
38
39impl<M: LanguageModel + 'static> Clone for AgentRuntime<M> {
40    fn clone(&self) -> Self {
41        Self {
42            teams: Arc::clone(&self.teams),
43            workflows: Arc::clone(&self.workflows),
44            agents: Arc::clone(&self.agents),
45            events: self.events.clone(),
46            trace_events: self.trace_events.clone(),
47            security: self.security.clone(),
48            access_control: self.access_control.clone(),
49            telemetry: self.telemetry.clone(),
50            metrics: self.metrics.clone(),
51        }
52    }
53}
54
55impl<M: LanguageModel + 'static> AgentRuntime<M> {
56    pub fn new() -> Self {
57        Self::with_security(SecurityConfig::default())
58    }
59
60    pub fn with_security(security: SecurityConfig) -> Self {
61        let (tx, _) = broadcast::channel(512);
62        let (trace_tx, _) = broadcast::channel::<TraceEvent>(256);
63        let access_control = AccessController::new();
64        access_control.allow(GovernanceRole::User, Action::SendMessage);
65        access_control.allow(GovernanceRole::Service, Action::SendMessage);
66        access_control.allow(GovernanceRole::Admin, Action::SendMessage);
67        access_control.allow(GovernanceRole::User, Action::ReadTranscript);
68        access_control.allow(GovernanceRole::Service, Action::ReadTranscript);
69        Self {
70            teams: Arc::new(RwLock::new(HashMap::new())),
71            workflows: Arc::new(RwLock::new(HashMap::new())),
72            agents: Arc::new(RwLock::new(HashMap::new())),
73            events: tx,
74            trace_events: trace_tx,
75            security,
76            access_control,
77            telemetry: TelemetryCollector::default(),
78            metrics: crate::MetricsTracker::default(),
79        }
80    }
81
82    pub async fn register_team(&self, name: impl Into<String>, team: Team<M>) {
83        self.teams.write().await.insert(name.into(), team);
84    }
85
86    pub async fn register_workflow(&self, flow: Workflow) {
87        self.workflows
88            .write()
89            .await
90            .insert(flow.name.clone(), Arc::new(flow));
91    }
92
93    pub async fn register_agent(&self, name: impl Into<String>, mut agent: crate::Agent<M>) {
94        let name = name.into();
95        let controller = Arc::new(self.access_control.clone());
96        agent.attach_access_control(controller);
97        agent.attach_metrics(self.metrics.clone());
98        agent.attach_telemetry(self.telemetry.clone());
99        for tool in agent.tool_names() {
100            self.access_control
101                .allow(GovernanceRole::User, Action::CallTool(tool.clone()));
102            self.access_control
103                .allow(GovernanceRole::Admin, Action::CallTool(tool.clone()));
104            self.access_control
105                .allow(GovernanceRole::Service, Action::CallTool(tool.clone()));
106        }
107        self.agents
108            .write()
109            .await
110            .insert(name, Arc::new(Mutex::new(agent)));
111    }
112
113    fn resolve_tenant(
114        &self,
115        headers: &HeaderMap,
116        override_tenant: &Option<String>,
117    ) -> std::result::Result<Option<String>, Response> {
118        let tenant = override_tenant.clone().or_else(|| {
119            headers
120                .get("x-tenant")
121                .and_then(|h| h.to_str().ok().map(|v| v.to_string()))
122        });
123
124        if !self.security.allowed_tenants.is_empty() {
125            if let Some(ref t) = tenant {
126                if !self
127                    .security
128                    .allowed_tenants
129                    .iter()
130                    .any(|allowed| allowed == t)
131                {
132                    return Err(json_error(
133                        StatusCode::UNAUTHORIZED,
134                        "tenant not allowed for this deployment",
135                    ));
136                }
137            } else {
138                return Err(json_error(
139                    StatusCode::UNAUTHORIZED,
140                    "tenant is required for this deployment",
141                ));
142            }
143        }
144
145        Ok(tenant)
146    }
147
148    fn build_principal(
149        &self,
150        headers: &HeaderMap,
151        body: &AgentChatRequest,
152    ) -> std::result::Result<Principal, Response> {
153        let tenant = self.resolve_tenant(headers, &body.tenant)?;
154        let role = parse_role(
155            headers
156                .get("x-principal-role")
157                .and_then(|h| h.to_str().ok())
158                .or_else(|| body.role.as_deref()),
159        );
160        let principal_id = headers
161            .get("x-principal-id")
162            .and_then(|h| h.to_str().ok())
163            .map(|s| s.to_string())
164            .or_else(|| body.principal_id.clone())
165            .unwrap_or_else(|| "anonymous".into());
166
167        Ok(Principal {
168            id: principal_id,
169            role,
170            tenant,
171        })
172    }
173
174    fn publish_trace(&self, agent: &str, tenant: Option<String>, kind: TraceKind) {
175        let _ = self.trace_events.send(TraceEvent {
176            agent: agent.to_string(),
177            tenant,
178            kind,
179        });
180    }
181
182    fn emit_tool_traces(&self, agent: &str, tenant: Option<String>, messages: &[Message]) {
183        for message in messages {
184            if let Some(call) = &message.tool_call {
185                self.publish_trace(
186                    agent,
187                    tenant.clone(),
188                    TraceKind::ToolCall {
189                        name: call.name.clone(),
190                        arguments: call.arguments.clone(),
191                    },
192                );
193            }
194            if let Some(result) = &message.tool_result {
195                self.publish_trace(
196                    agent,
197                    tenant.clone(),
198                    TraceKind::ToolResult {
199                        name: result.name.clone(),
200                        output: result.output.clone(),
201                    },
202                );
203            }
204        }
205    }
206
207    pub async fn serve(self, addr: SocketAddr) -> Result<()> {
208        let app = Router::new()
209            .route("/health", get(|| async { "ok" }))
210            .route("/metrics", get(prometheus_metrics))
211            .route("/dashboard", get(dashboard))
212            .route("/agents", get(list_agents::<M>))
213            .route("/agents/:id/chat", post(chat_with_agent::<M>))
214            .route("/agents/:id/traces", get(stream_tool_traces::<M>))
215            .route("/teams", get(list_teams::<M>))
216            .route("/workflows", get(list_workflows::<M>))
217            .route("/events", get(stream_events::<M>))
218            .route("/invoke", post(run_workflow::<M>))
219            .with_state(self.clone());
220
221        let listener = tokio::net::TcpListener::bind(addr).await?;
222        axum::serve(listener, app.into_make_service())
223            .await
224            .map_err(|err| crate::error::AgnoError::Protocol(format!("server error: {err}")))?;
225        Ok(())
226    }
227}
228
229#[derive(Serialize)]
230struct TeamSummary {
231    name: String,
232    agents: usize,
233}
234
235#[derive(Serialize)]
236struct AgentSummary {
237    name: String,
238    tools: usize,
239}
240
241#[derive(Deserialize)]
242struct AgentChatRequest {
243    message: String,
244    #[serde(default)]
245    principal_id: Option<String>,
246    #[serde(default)]
247    role: Option<String>,
248    #[serde(default)]
249    tenant: Option<String>,
250}
251
252#[derive(Serialize)]
253struct AgentChatResponse {
254    reply: String,
255    transcript: Vec<Message>,
256}
257
258#[derive(Deserialize, Default)]
259struct TraceAuth {
260    tenant: Option<String>,
261    principal_id: Option<String>,
262    role: Option<String>,
263}
264
265#[derive(Debug, Clone, Serialize)]
266struct TraceEvent {
267    agent: String,
268    tenant: Option<String>,
269    kind: TraceKind,
270}
271
272#[derive(Debug, Clone, Serialize)]
273#[serde(tag = "kind", rename_all = "snake_case")]
274enum TraceKind {
275    Started { message: String },
276    ToolCall { name: String, arguments: Value },
277    ToolResult { name: String, output: Value },
278    Completed { reply: String },
279    Failed { error: String },
280}
281
282fn json_error(status: StatusCode, message: &str) -> Response {
283    (status, Json(json!({"error": message}))).into_response()
284}
285
286fn parse_role(raw: Option<&str>) -> GovernanceRole {
287    match raw.unwrap_or("user").to_lowercase().as_str() {
288        "admin" => GovernanceRole::Admin,
289        "service" => GovernanceRole::Service,
290        _ => GovernanceRole::User,
291    }
292}
293
294async fn list_teams<M: LanguageModel + 'static>(
295    State(state): State<AgentRuntime<M>>,
296) -> impl IntoResponse {
297    let teams = state.teams.read().await;
298    let payload: Vec<TeamSummary> = teams
299        .iter()
300        .map(|(name, team)| TeamSummary {
301            name: name.clone(),
302            agents: team.size(),
303        })
304        .collect();
305    Json(payload)
306}
307
308async fn list_agents<M: LanguageModel + 'static>(
309    State(state): State<AgentRuntime<M>>,
310) -> impl IntoResponse {
311    let agents = state.agents.read().await;
312    let mut payload = Vec::new();
313    for (name, agent) in agents.iter() {
314        let guard = agent.lock().await;
315        payload.push(AgentSummary {
316            name: name.clone(),
317            tools: guard.tool_names().len(),
318        });
319    }
320    Json(payload)
321}
322
323#[derive(Serialize)]
324struct WorkflowSummary {
325    name: String,
326}
327
328async fn list_workflows<M: LanguageModel + 'static>(
329    State(state): State<AgentRuntime<M>>,
330) -> impl IntoResponse {
331    let flows = state.workflows.read().await;
332    let payload: Vec<WorkflowSummary> = flows
333        .iter()
334        .map(|(name, _)| WorkflowSummary { name: name.clone() })
335        .collect();
336    Json(payload)
337}
338
339async fn stream_events<M: LanguageModel + 'static>(
340    State(state): State<AgentRuntime<M>>,
341) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
342    let rx = state.events.subscribe();
343    let stream = BroadcastStream::new(rx).filter_map(|msg| async move {
344        match msg {
345            Ok(line) => Some(Ok::<Event, Infallible>(Event::default().data(line))),
346            Err(_) => None,
347        }
348    });
349    Sse::new(stream)
350}
351
352async fn stream_tool_traces<M: LanguageModel + 'static>(
353    State(state): State<AgentRuntime<M>>,
354    Path(agent_id): Path<String>,
355    Query(auth): Query<TraceAuth>,
356    headers: HeaderMap,
357) -> Response {
358    let tenant = match state.resolve_tenant(&headers, &auth.tenant) {
359        Ok(t) => t,
360        Err(resp) => return resp,
361    };
362    let principal = match state.build_principal(
363        &headers,
364        &AgentChatRequest {
365            message: String::new(),
366            principal_id: auth.principal_id.clone(),
367            role: auth.role.clone(),
368            tenant: tenant.clone(),
369        },
370    ) {
371        Ok(principal) => principal,
372        Err(resp) => return resp,
373    };
374    if !state
375        .access_control
376        .authorize(&principal, &Action::ReadTranscript)
377    {
378        return json_error(
379            StatusCode::FORBIDDEN,
380            "principal not authorized to read traces",
381        );
382    }
383
384    state.telemetry.record(
385        "sse_subscribe",
386        json!({"agent": agent_id.clone(), "tenant": principal.tenant, "principal": principal.id}),
387        crate::TelemetryLabels::default().with_tenant(principal.tenant.clone().unwrap_or_default()),
388    );
389
390    let rx = state.trace_events.subscribe();
391    let stream = BroadcastStream::new(rx).filter_map(move |msg| {
392        let tenant = tenant.clone();
393        let agent_id = agent_id.clone();
394        async move {
395            match msg {
396                Ok(event) => {
397                    if event.agent != agent_id {
398                        return None;
399                    }
400                    if let Some(ref t) = tenant {
401                        if event.tenant.as_ref() != Some(t) {
402                            return None;
403                        }
404                    }
405                    serde_json::to_string(&event)
406                        .ok()
407                        .map(|payload| Ok::<Event, Infallible>(Event::default().data(payload)))
408                }
409                Err(_) => None,
410            }
411        }
412    });
413    Sse::new(stream).into_response()
414}
415
416#[derive(serde::Deserialize)]
417struct WorkflowRequest {
418    name: String,
419}
420
421async fn run_workflow<M: LanguageModel + 'static>(
422    State(state): State<AgentRuntime<M>>,
423    Json(req): Json<WorkflowRequest>,
424) -> Response {
425    let flow = { state.workflows.read().await.get(&req.name).cloned() };
426    if let Some(flow) = flow {
427        let mut ctx = crate::WorkflowContext::default();
428        match flow.run(&mut ctx).await {
429            Ok(value) => {
430                let _ = state
431                    .events
432                    .send(format!("workflow:{} completed", flow.name));
433                Json(json!({ "result": value, "state": ctx.state })).into_response()
434            }
435            Err(err) => (
436                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
437                Json(json!({"error": err.to_string()})),
438            )
439                .into_response(),
440        }
441    } else {
442        (
443            axum::http::StatusCode::NOT_FOUND,
444            Json(json!({"error":"workflow not found"})),
445        )
446            .into_response()
447    }
448}
449
450async fn chat_with_agent<M: LanguageModel + 'static>(
451    State(state): State<AgentRuntime<M>>,
452    Path(agent_id): Path<String>,
453    headers: HeaderMap,
454    Json(req): Json<AgentChatRequest>,
455) -> Response {
456    let principal = match state.build_principal(&headers, &req) {
457        Ok(p) => p,
458        Err(resp) => return resp,
459    };
460
461    if !state
462        .access_control
463        .authorize(&principal, &Action::SendMessage)
464    {
465        return json_error(
466            StatusCode::FORBIDDEN,
467            "principal not authorized to message this agent",
468        );
469    }
470
471    let agent = { state.agents.read().await.get(&agent_id).cloned() };
472    let Some(agent) = agent else {
473        return json_error(StatusCode::NOT_FOUND, "agent not registered");
474    };
475
476    state.telemetry.record(
477        "http_request",
478        json!({"path": format!("/agents/{}/chat", agent_id), "tenant": principal.tenant, "principal": principal.id}),
479        crate::TelemetryLabels::default().with_tenant(principal.tenant.clone().unwrap_or_default()),
480    );
481
482    let mut guard = agent.lock().await;
483    guard.set_principal(principal.clone());
484    guard.attach_access_control(Arc::new(state.access_control.clone()));
485    guard.attach_metrics(state.metrics.clone());
486    guard.attach_telemetry(state.telemetry.clone());
487
488    let starting_len = guard.memory().len();
489    state.publish_trace(
490        &agent_id,
491        principal.tenant.clone(),
492        TraceKind::Started {
493            message: req.message.clone(),
494        },
495    );
496
497    let result = guard
498        .respond_for(principal.clone(), req.message.clone())
499        .await;
500    let transcript: Vec<Message> = guard.memory().iter().cloned().collect();
501    let new_segment: Vec<Message> = guard.memory().iter().skip(starting_len).cloned().collect();
502    drop(guard);
503
504    state.emit_tool_traces(&agent_id, principal.tenant.clone(), &new_segment);
505
506    match result {
507        Ok(reply) => {
508            state.publish_trace(
509                &agent_id,
510                principal.tenant.clone(),
511                TraceKind::Completed {
512                    reply: reply.clone(),
513                },
514            );
515            state.telemetry.record(
516                "http_response",
517                json!({"path": format!("/agents/{}/chat", agent_id), "status": 200, "tenant": principal.tenant}),
518                crate::TelemetryLabels::default().with_tenant(principal.tenant.clone().unwrap_or_default()),
519            );
520            Json(AgentChatResponse { reply, transcript }).into_response()
521        }
522        Err(err) => {
523            state.publish_trace(
524                &agent_id,
525                principal.tenant.clone(),
526                TraceKind::Failed {
527                    error: err.to_string(),
528                },
529            );
530            state.telemetry.record(
531                "http_response",
532                json!({"path": format!("/agents/{}/chat", agent_id), "status": 502, "error": err.to_string()}),
533                crate::TelemetryLabels::default().with_tenant(principal.tenant.clone().unwrap_or_default()),
534            );
535            (
536                StatusCode::BAD_GATEWAY,
537                Json(json!({"error": err.to_string()})),
538            )
539                .into_response()
540        }
541    }
542}
543
544async fn prometheus_metrics() -> impl IntoResponse {
545    use prometheus::Encoder;
546    let registry = crate::metrics::init_prometheus_registry();
547    let metric_families = registry.gather();
548    let encoder = prometheus::TextEncoder::new();
549    let mut buffer = Vec::new();
550    if encoder.encode(&metric_families, &mut buffer).is_err() {
551        return (
552            StatusCode::INTERNAL_SERVER_ERROR,
553            "failed to encode metrics",
554        )
555            .into_response();
556    }
557    (
558        [(axum::http::header::CONTENT_TYPE, "text/plain; charset=utf-8")],
559        buffer,
560    )
561        .into_response()
562}
563
564async fn dashboard() -> Html<&'static str> {
565    Html(
566        r#"
567<!doctype html>
568<html>
569<head>
570    <meta charset="utf-8" />
571    <title>AgentOS Dashboard</title>
572    <style>
573        body { font-family: sans-serif; margin: 2rem; }
574        .column { float: left; width: 45%; margin-right: 5%; }
575        .panel { border: 1px solid #ccc; padding: 1rem; margin-bottom: 1rem; border-radius: 8px; }
576        h2 { margin-top: 0; }
577        #events { background: #111; color: #0f0; height: 200px; overflow: auto; font-family: monospace; padding: 1rem; }
578        #trace-log { background: #f6f8fa; height: 200px; overflow: auto; font-family: monospace; padding: 0.75rem; }
579        #chat-output { background: #f6f8fa; min-height: 80px; padding: 0.75rem; white-space: pre-wrap; }
580        label { display: block; margin-top: 0.5rem; font-weight: 600; }
581        input, select, textarea { width: 100%; padding: 0.35rem; margin-top: 0.25rem; }
582        button { margin-top: 0.5rem; padding: 0.5rem 1rem; }
583    </style>
584</head>
585<body>
586    <h1>AgentOS</h1>
587    <div class="column">
588        <div class="panel">
589            <h2>Teams</h2>
590            <ul id="teams"></ul>
591        </div>
592        <div class="panel">
593            <h2>Workflows</h2>
594            <ul id="workflows"></ul>
595        </div>
596    </div>
597    <div class="column">
598        <div class="panel">
599            <h2>Events</h2>
600            <div id="events"></div>
601        </div>
602    </div>
603    <div class="column">
604        <div class="panel">
605            <h2>Agents</h2>
606            <select id="agent-select"></select>
607            <div id="agent-tools"></div>
608        </div>
609        <div class="panel">
610            <h2>Chat</h2>
611            <label for="tenant">Tenant (x-tenant)</label>
612            <input id="tenant" placeholder="acme-co" />
613            <label for="principal">Principal ID</label>
614            <input id="principal" placeholder="user-123" />
615            <label for="role">Role</label>
616            <select id="role">
617                <option value="user">User</option>
618                <option value="admin">Admin</option>
619                <option value="service">Service</option>
620            </select>
621            <label for="chat-input">Message</label>
622            <textarea id="chat-input" placeholder="Ask an agent..."></textarea>
623            <button onclick="sendChat()">Send</button>
624            <div id="chat-output"></div>
625        </div>
626        <div class="panel">
627            <h2>Tool traces (SSE)</h2>
628            <div id="trace-log"></div>
629        </div>
630    </div>
631    <script>
632        let traceSource = null;
633        async function load() {
634            const teams = await fetch('/teams').then(r => r.json());
635            document.getElementById('teams').innerHTML = teams.map(t => `<li>${t.name} (${t.agents} agents)</li>`).join('');
636            const workflows = await fetch('/workflows').then(r => r.json());
637            document.getElementById('workflows').innerHTML = workflows.map(w => `<li>${w.name}</li>`).join('');
638            await refreshAgents();
639        }
640        load();
641        const evt = new EventSource('/events');
642        evt.onmessage = (ev) => {
643            const node = document.getElementById('events');
644            node.innerText += ev.data + "\n";
645            node.scrollTop = node.scrollHeight;
646        };
647
648        async function refreshAgents() {
649            const agents = await fetch('/agents').then(r => r.json());
650            const select = document.getElementById('agent-select');
651            select.innerHTML = agents.map(a => `<option value="${a.name}">${a.name} (${a.tools} tools)</option>`).join('');
652            if (agents.length) {
653                document.getElementById('agent-tools').innerText = `Tools: ${agents[0].tools}`;
654                select.value = agents[0].name;
655                subscribeTraces();
656            } else {
657                document.getElementById('agent-tools').innerText = 'No agents registered.';
658            }
659            select.onchange = () => {
660                const selected = agents.find(a => a.name === select.value);
661                document.getElementById('agent-tools').innerText = selected ? `Tools: ${selected.tools}` : '';
662                subscribeTraces();
663            };
664        }
665
666        function subscribeTraces() {
667            const agent = document.getElementById('agent-select').value;
668            if (!agent) return;
669            if (traceSource) traceSource.close();
670            const params = new URLSearchParams();
671            const tenant = document.getElementById('tenant').value;
672            const principal = document.getElementById('principal').value;
673            const role = document.getElementById('role').value;
674            if (tenant) params.append('tenant', tenant);
675            if (principal) params.append('principal_id', principal);
676            if (role) params.append('role', role);
677            const url = `/agents/${agent}/traces${params.toString() ? '?' + params.toString() : ''}`;
678            traceSource = new EventSource(url);
679            traceSource.onmessage = (ev) => {
680                const log = document.getElementById('trace-log');
681                try {
682                    const data = JSON.parse(ev.data);
683                    log.innerText += `[${data.kind}] ${JSON.stringify(data)}\n`;
684                } catch (e) {
685                    log.innerText += ev.data + "\n";
686                }
687                log.scrollTop = log.scrollHeight;
688            };
689        }
690
691        async function sendChat() {
692            const agent = document.getElementById('agent-select').value;
693            const message = document.getElementById('chat-input').value;
694            const tenant = document.getElementById('tenant').value;
695            const principal = document.getElementById('principal').value;
696            const role = document.getElementById('role').value;
697            if (!agent || !message) {
698                alert('Select an agent and enter a message.');
699                return;
700            }
701            const headers = {'Content-Type': 'application/json'};
702            if (tenant) headers['X-Tenant'] = tenant;
703            if (principal) headers['X-Principal-Id'] = principal;
704            if (role) headers['X-Principal-Role'] = role;
705            const payload = {message, tenant, principal_id: principal, role};
706            const res = await fetch(`/agents/${agent}/chat`, {
707                method: 'POST',
708                headers,
709                body: JSON.stringify(payload),
710            });
711            const body = await res.json();
712            if (res.ok) {
713                document.getElementById('chat-output').innerText = `Reply: ${body.reply}`;
714                subscribeTraces();
715            } else {
716                document.getElementById('chat-output').innerText = `Error: ${body.error}`;
717            }
718        }
719    </script>
720</body>
721</html>
722"#,
723    )
724}