agent_kernel/
call.rs

1//! Call message execution pipeline.
2
3use std::fmt;
4use std::net::SocketAddr;
5use std::sync::{Arc, Mutex};
6
7use agent_adapters::traits::{
8    AdapterError, InferenceRequest, MessageRole, ModelAdapter, PromptMessage,
9};
10use agent_memory::{MemoryBus, MemoryChannel, MemoryError, MemoryRecord};
11use agent_policy::{
12    DecisionKind, PolicyAction, PolicyDecision, PolicyEngine, PolicyError, PolicyRequest,
13};
14use agent_primitives::AgentId;
15use agent_tools::registry::{ToolBinding, ToolError, ToolRegistry, descriptor_from_type_name};
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures::StreamExt;
19use mxp::{Message, MessageType, TransportHandle};
20use serde::Deserialize;
21use serde_json::{Value, json};
22use tokio::task;
23use tracing::{debug, info, warn};
24
25use crate::{HandlerContext, HandlerError, HandlerResult};
26
27/// Emits MXP audit events when policy decisions deny or escalate requests.
28pub trait AuditEmitter: Send + Sync {
29    /// Emits the supplied MXP event message.
30    fn emit(&self, message: Message);
31}
32
33/// Fan-out emitter that broadcasts audit events to multiple sinks.
34pub struct CompositeAuditEmitter {
35    emitters: Vec<Arc<dyn AuditEmitter>>,
36}
37
38impl CompositeAuditEmitter {
39    /// Creates a composite emitter from the supplied iterator.
40    #[must_use]
41    pub fn new<I>(emitters: I) -> Self
42    where
43        I: IntoIterator<Item = Arc<dyn AuditEmitter>>,
44    {
45        Self {
46            emitters: emitters.into_iter().collect(),
47        }
48    }
49
50    /// Adds an emitter to the composite collection.
51    pub fn push(&mut self, emitter: Arc<dyn AuditEmitter>) {
52        self.emitters.push(emitter);
53    }
54}
55
56impl AuditEmitter for CompositeAuditEmitter {
57    fn emit(&self, message: Message) {
58        for emitter in &self.emitters {
59            emitter.emit(message.clone());
60        }
61    }
62}
63
64/// Tracing-based audit emitter that logs MXP audit events.
65#[derive(Default)]
66pub struct TracingAuditEmitter;
67
68impl AuditEmitter for TracingAuditEmitter {
69    fn emit(&self, message: Message) {
70        let payload = String::from_utf8_lossy(message.payload());
71        info!(
72            event = ?message.message_type(),
73            payload = %payload,
74            "policy audit event emitted"
75        );
76    }
77}
78
79/// Sends audit events to a remote governance agent using MXP transport.
80#[derive(Clone)]
81pub struct GovernanceAuditEmitter {
82    transport: TransportHandle,
83    target: SocketAddr,
84}
85
86impl GovernanceAuditEmitter {
87    /// Creates a new governance emitter.
88    #[must_use]
89    pub fn new(transport: TransportHandle, target: SocketAddr) -> Self {
90        Self { transport, target }
91    }
92}
93
94impl AuditEmitter for GovernanceAuditEmitter {
95    fn emit(&self, message: Message) {
96        let transport = self.transport.clone();
97        let target = self.target;
98        let msg_type = message.message_type();
99        let message_id = message.message_id();
100        let trace_id = message.trace_id();
101        let encoded = message.encode();
102
103        task::spawn(async move {
104            if let Err(err) = transport.send(&encoded, target) {
105                warn!(
106                    ?err,
107                    %target,
108                    message_id,
109                    trace_id,
110                    ?msg_type,
111                    "failed to deliver governance audit event",
112                );
113            }
114        });
115    }
116}
117
118/// Observer invoked whenever a policy decision is produced.
119pub trait PolicyObserver: Send + Sync {
120    /// Records the decision emitted for the supplied request subject.
121    fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str);
122}
123
124/// Observer that emits decisions to the tracing system.
125#[derive(Default)]
126pub struct TracingPolicyObserver;
127
128impl PolicyObserver for TracingPolicyObserver {
129    fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
130        let reason = decision.reason().unwrap_or_default();
131        let approvers = decision.required_approvals();
132        match decision.kind() {
133            DecisionKind::Allow => {
134                debug!(agent_id = %request.agent_id(), subject, "policy allow");
135            }
136            DecisionKind::Deny => {
137                warn!(
138                    agent_id = %request.agent_id(),
139                    subject,
140                    reason,
141                    "policy deny"
142                );
143            }
144            DecisionKind::Escalate => {
145                warn!(
146                    agent_id = %request.agent_id(),
147                    subject,
148                    reason,
149                    approvers = ?approvers,
150                    "policy escalate"
151                );
152            }
153        }
154    }
155}
156
157/// Composite observer that forwards decisions to a collection of observers.
158pub struct CompositePolicyObserver {
159    observers: Vec<Arc<dyn PolicyObserver>>,
160}
161
162impl CompositePolicyObserver {
163    /// Creates a new composite observer from the supplied list.
164    #[must_use]
165    pub fn new<I>(observers: I) -> Self
166    where
167        I: IntoIterator<Item = Arc<dyn PolicyObserver>>,
168    {
169        Self {
170            observers: observers.into_iter().collect(),
171        }
172    }
173
174    /// Adds an observer to the composite set.
175    pub fn push(&mut self, observer: Arc<dyn PolicyObserver>) {
176        self.observers.push(observer);
177    }
178}
179
180impl PolicyObserver for CompositePolicyObserver {
181    fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
182        for observer in &self.observers {
183            observer.on_decision(request, decision, subject);
184        }
185    }
186}
187
188/// Observer that emits MXP audit events for deny/escalate outcomes.
189pub struct MxpAuditObserver {
190    emitter: Arc<dyn AuditEmitter>,
191}
192
193impl MxpAuditObserver {
194    /// Creates a new MXP audit observer using the provided emitter.
195    #[must_use]
196    pub fn new(emitter: Arc<dyn AuditEmitter>) -> Self {
197        Self { emitter }
198    }
199}
200
201impl PolicyObserver for MxpAuditObserver {
202    fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
203        if matches!(decision.kind(), DecisionKind::Deny | DecisionKind::Escalate) {
204            let payload = json!({
205                "agent_id": request.agent_id().to_string(),
206                "subject": subject,
207                "decision": format!("{:?}", decision.kind()),
208                "reason": decision.reason(),
209                "approvers": decision.required_approvals(),
210                "metadata": request.context().metadata(),
211            });
212            let payload_string = payload.to_string();
213            let message = Message::new(MessageType::Event, payload_string.as_bytes());
214            self.emitter.emit(message);
215        }
216    }
217}
218
219/// Executes MXP `Call` messages by invoking registered tools and the
220/// configured [`ModelAdapter`].
221#[derive(Clone)]
222pub struct CallExecutor {
223    adapter: Arc<dyn ModelAdapter>,
224    tools: Arc<ToolRegistry>,
225    policy: Option<Arc<dyn PolicyEngine>>,
226    policy_observer: Option<Arc<dyn PolicyObserver>>,
227}
228
229impl fmt::Debug for CallExecutor {
230    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231        let metadata = self.adapter.metadata();
232        f.debug_struct("CallExecutor")
233            .field("provider", &metadata.provider())
234            .field("model", &metadata.model())
235            .field("policy_configured", &self.policy.is_some())
236            .field("observer_configured", &self.policy_observer.is_some())
237            .finish_non_exhaustive()
238    }
239}
240
241impl CallExecutor {
242    /// Creates a new call executor.
243    #[must_use]
244    pub fn new(adapter: Arc<dyn ModelAdapter>, tools: Arc<ToolRegistry>) -> Self {
245        Self {
246            adapter,
247            tools,
248            policy: None,
249            policy_observer: None,
250        }
251    }
252
253    /// Configures the policy engine used for governance decisions.
254    pub fn set_policy(&mut self, policy: Arc<dyn PolicyEngine>) {
255        self.policy = Some(policy);
256    }
257
258    /// Configures the policy engine, returning the updated executor for chaining.
259    #[must_use]
260    pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
261        self.set_policy(policy);
262        self
263    }
264
265    /// Returns the policy engine if one has been configured.
266    #[must_use]
267    pub fn policy(&self) -> Option<&Arc<dyn PolicyEngine>> {
268        self.policy.as_ref()
269    }
270
271    /// Installs a policy observer for integration hooks.
272    pub fn set_policy_observer(&mut self, observer: Arc<dyn PolicyObserver>) {
273        self.policy_observer = Some(observer);
274    }
275
276    /// Configures a policy observer, returning the updated executor for chaining.
277    #[must_use]
278    pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
279        self.set_policy_observer(observer);
280        self
281    }
282
283    /// Returns the policy observer if configured.
284    #[must_use]
285    pub fn policy_observer(&self) -> Option<&Arc<dyn PolicyObserver>> {
286        self.policy_observer.as_ref()
287    }
288
289    fn notify_policy(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
290        if let Some(observer) = &self.policy_observer {
291            observer.on_decision(request, decision, subject);
292        }
293    }
294
295    async fn enforce_tool_policy(
296        &self,
297        ctx: &HandlerContext,
298        invocation: &ToolInvocation,
299    ) -> HandlerResult<()> {
300        let Some(policy) = self.policy.as_ref() else {
301            return Ok(());
302        };
303
304        let mut request = PolicyRequest::new(
305            ctx.agent_id(),
306            PolicyAction::InvokeTool {
307                name: invocation.name.clone(),
308            },
309        );
310
311        request
312            .context_mut()
313            .insert_metadata("input", invocation.input.clone());
314
315        if let Some(handle) = self.tools.get(&invocation.name) {
316            let metadata = handle.metadata().clone();
317            request
318                .context_mut()
319                .insert_metadata("tool_version", Value::from(metadata.version().to_owned()));
320            if let Some(description) = metadata.description() {
321                request
322                    .context_mut()
323                    .insert_metadata("tool_description", Value::from(description.to_owned()));
324            }
325
326            if !metadata.capabilities().is_empty() {
327                let capabilities: Vec<String> = metadata
328                    .capabilities()
329                    .iter()
330                    .map(|cap| cap.as_str().to_owned())
331                    .collect();
332                request
333                    .context_mut()
334                    .insert_metadata("capabilities", Value::from(capabilities.clone()));
335                request
336                    .context_mut()
337                    .extend_tags(capabilities.iter().map(|cap| format!("cap:{cap}")));
338            }
339        }
340
341        let decision = policy
342            .evaluate(&request)
343            .await
344            .map_err(|err| map_policy_error(&err))?;
345
346        self.notify_policy(&request, &decision, &request.action().label());
347        enforce_decision(&decision, &request.action().label())
348    }
349
350    async fn enforce_inference_policy(
351        &self,
352        ctx: &HandlerContext,
353        message_count: usize,
354        tool_names: &[String],
355    ) -> HandlerResult<()> {
356        let Some(policy) = self.policy.as_ref() else {
357            return Ok(());
358        };
359
360        let metadata = self.adapter.metadata();
361        let mut request = PolicyRequest::new(
362            ctx.agent_id(),
363            PolicyAction::ModelInference {
364                provider: metadata.provider().to_owned(),
365                model: metadata.model().to_owned(),
366            },
367        );
368
369        request
370            .context_mut()
371            .insert_metadata("message_count", Value::from(message_count as u64));
372
373        if !tool_names.is_empty() {
374            request
375                .context_mut()
376                .insert_metadata("tools", Value::from(tool_names.to_owned()));
377            request
378                .context_mut()
379                .extend_tags(tool_names.iter().map(|name| format!("tool:{name}")));
380        }
381
382        let decision = policy
383            .evaluate(&request)
384            .await
385            .map_err(|err| map_policy_error(&err))?;
386
387        self.notify_policy(&request, &decision, &request.action().label());
388        enforce_decision(&decision, &request.action().label())
389    }
390
391    /// Executes the call pipeline using data extracted from the handler context.
392    ///
393    /// # Errors
394    ///
395    /// Returns [`HandlerError`] when payload decoding, tool execution, or model
396    /// inference fails.
397    pub async fn execute(&self, ctx: &HandlerContext) -> HandlerResult<CallOutcome> {
398        let payload = parse_payload(ctx)?;
399
400        let mut messages = payload.messages;
401        let mut tool_names = Vec::new();
402        let mut tool_results = Vec::new();
403
404        for invocation in payload.tools {
405            self.enforce_tool_policy(ctx, &invocation).await?;
406
407            let tool_output = self
408                .tools
409                .invoke(&invocation.name, invocation.input.clone())
410                .await
411                .map_err(|err| map_tool_error(&invocation.name, &err))?;
412
413            let message_content =
414                serde_json::to_string(&tool_output).unwrap_or_else(|_| String::new());
415            messages.push(PromptMessage::new(MessageRole::Tool, message_content));
416            tool_names.push(invocation.name.clone());
417            tool_results.push(ToolInvocationResult {
418                name: invocation.name,
419                output: tool_output,
420            });
421        }
422
423        self.enforce_inference_policy(ctx, messages.len(), &tool_names)
424            .await?;
425
426        let mut request = InferenceRequest::new(messages)
427            .map_err(|err| HandlerError::custom(format!("invalid request: {err}")))?;
428
429        if let Some(max_tokens) = payload.max_output_tokens {
430            request = request.with_max_output_tokens(max_tokens);
431        }
432
433        if let Some(temperature) = payload.temperature {
434            request = request.with_temperature(temperature);
435        }
436
437        if !tool_names.is_empty() {
438            request = request.with_tools(tool_names);
439        }
440
441        let mut stream = self
442            .adapter
443            .infer(request)
444            .await
445            .map_err(|err| map_adapter_error(&err, self.adapter.metadata()))?;
446
447        let mut response = String::new();
448        while let Some(chunk) = stream.next().await {
449            let chunk = chunk.map_err(|err| map_adapter_error(&err, self.adapter.metadata()))?;
450            response.push_str(&chunk.delta);
451            if chunk.done {
452                break;
453            }
454        }
455
456        Ok(CallOutcome {
457            response,
458            tool_results,
459        })
460    }
461}
462
463fn parse_payload(ctx: &HandlerContext) -> HandlerResult<CallPayload> {
464    let payload = ctx.message().payload();
465    if payload.is_empty() {
466        return Err(HandlerError::custom("call payload missing"));
467    }
468
469    serde_json::from_slice::<CallPayload>(payload.as_ref())
470        .map_err(|err| HandlerError::custom(format!("failed to decode call payload: {err}")))
471}
472
473fn map_tool_error(name: &str, err: &ToolError) -> HandlerError {
474    HandlerError::custom(format!("tool `{name}` failed: {err}"))
475}
476
477fn map_adapter_error(
478    err: &AdapterError,
479    metadata: &agent_adapters::traits::AdapterMetadata,
480) -> HandlerError {
481    HandlerError::custom(format!(
482        "adapter `{}` for model `{}` error: {err}",
483        metadata.provider(),
484        metadata.model()
485    ))
486}
487
488fn map_memory_error(err: &MemoryError) -> HandlerError {
489    HandlerError::custom(format!("memory error: {err}"))
490}
491
492fn map_policy_error(err: &PolicyError) -> HandlerError {
493    HandlerError::custom(format!("policy engine error: {err}"))
494}
495
496fn enforce_decision(decision: &PolicyDecision, subject: &str) -> HandlerResult<()> {
497    match decision.kind() {
498        DecisionKind::Allow => Ok(()),
499        DecisionKind::Deny => {
500            let reason = decision.reason().unwrap_or("policy denied the request");
501            Err(HandlerError::custom(format!(
502                "policy denied {subject}: {reason}"
503            )))
504        }
505        DecisionKind::Escalate => {
506            let reason = decision.reason().unwrap_or("policy escalation required");
507            let approvers = decision.required_approvals();
508            let detail = if approvers.is_empty() {
509                reason.to_owned()
510            } else {
511                format!("{reason} (approvers: {})", approvers.join(", "))
512            };
513            Err(HandlerError::custom(format!(
514                "policy escalation required for {subject}: {detail}"
515            )))
516        }
517    }
518}
519
520/// Outcome of processing a call message.
521#[derive(Debug)]
522pub struct CallOutcome {
523    response: String,
524    tool_results: Vec<ToolInvocationResult>,
525}
526
527impl CallOutcome {
528    /// Returns the aggregated model response text.
529    #[must_use]
530    pub fn response(&self) -> &str {
531        &self.response
532    }
533
534    /// Returns the tool invocation results that were executed as part of this call.
535    #[must_use]
536    pub fn tool_results(&self) -> &[ToolInvocationResult] {
537        &self.tool_results
538    }
539}
540
541/// Result describing an executed tool invocation.
542#[derive(Debug, Clone)]
543pub struct ToolInvocationResult {
544    /// Name of the tool that was invoked.
545    pub name: String,
546    /// Output produced by the tool.
547    pub output: Value,
548}
549
550#[derive(Debug, Deserialize)]
551struct CallPayload {
552    messages: Vec<PromptMessage>,
553    #[serde(default)]
554    temperature: Option<f32>,
555    #[serde(default)]
556    max_output_tokens: Option<u32>,
557    #[serde(default)]
558    tools: Vec<ToolInvocation>,
559}
560
561#[derive(Debug, Deserialize)]
562struct ToolInvocation {
563    name: String,
564    #[serde(default)]
565    input: Value,
566}
567
568/// Handler implementation that wires the call executor into the MXP handler trait.
569pub struct KernelMessageHandler {
570    executor: Arc<CallExecutor>,
571    sink: Arc<dyn CallOutcomeSink>,
572    memory: Option<Arc<MemoryBus>>,
573}
574
575impl KernelMessageHandler {
576    /// Creates a new handler using the provided adapter and registry.
577    #[must_use]
578    pub fn new(
579        adapter: Arc<dyn ModelAdapter>,
580        tools: Arc<ToolRegistry>,
581        sink: Arc<dyn CallOutcomeSink>,
582    ) -> Self {
583        let executor = Arc::new(CallExecutor::new(adapter, tools));
584        Self {
585            executor,
586            sink,
587            memory: None,
588        }
589    }
590
591    /// Creates a builder that automates tool registration and optional components.
592    #[must_use]
593    pub fn builder(
594        adapter: Arc<dyn ModelAdapter>,
595        sink: Arc<dyn CallOutcomeSink>,
596    ) -> KernelMessageHandlerBuilder {
597        KernelMessageHandlerBuilder::new(adapter, sink)
598    }
599
600    /// Configures the memory bus used to persist call transcripts.
601    #[must_use]
602    pub fn with_memory(mut self, memory: Arc<MemoryBus>) -> Self {
603        self.memory = Some(memory);
604        self
605    }
606
607    /// Installs or replaces the memory bus after construction.
608    pub fn set_memory(&mut self, memory: Arc<MemoryBus>) {
609        self.memory = Some(memory);
610    }
611
612    /// Configures the policy engine used to guard tool execution and model inference.
613    #[must_use]
614    pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
615        self.set_policy(policy);
616        self
617    }
618
619    /// Installs or replaces the policy engine after construction.
620    pub fn set_policy(&mut self, policy: Arc<dyn PolicyEngine>) {
621        Arc::make_mut(&mut self.executor).set_policy(policy);
622    }
623
624    /// Configures the policy observer used to record governance decisions.
625    #[must_use]
626    pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
627        self.set_policy_observer(observer);
628        self
629    }
630
631    /// Installs or replaces the policy observer after construction.
632    pub fn set_policy_observer(&mut self, observer: Arc<dyn PolicyObserver>) {
633        Arc::make_mut(&mut self.executor).set_policy_observer(observer);
634    }
635
636    /// Returns the configured policy observer, if any.
637    #[must_use]
638    pub fn policy_observer(&self) -> Option<&Arc<dyn PolicyObserver>> {
639        self.executor.policy_observer()
640    }
641
642    /// Returns the configured memory bus, if any.
643    #[must_use]
644    pub fn memory(&self) -> Option<&Arc<MemoryBus>> {
645        self.memory.as_ref()
646    }
647
648    async fn record_inbound(&self, ctx: &HandlerContext) -> HandlerResult<()> {
649        let Some(memory) = &self.memory else {
650            return Ok(());
651        };
652
653        let record = MemoryRecord::builder(MemoryChannel::Input, ctx.message().payload().clone())
654            .tag("mxp.call")
655            .map_err(|err| map_memory_error(&err))?
656            .metadata("direction", Value::from("inbound"))
657            .metadata("message_type", Value::from("call"))
658            .metadata("agent_id", Value::from(ctx.agent_id().to_string()))
659            .build()
660            .map_err(|err| map_memory_error(&err))?;
661
662        self.enforce_memory_policy(ctx.agent_id(), &record).await?;
663        memory
664            .record(record)
665            .await
666            .map_err(|err| map_memory_error(&err))?;
667        Ok(())
668    }
669
670    async fn record_outbound(&self, agent_id: AgentId, outcome: &CallOutcome) -> HandlerResult<()> {
671        let Some(memory) = &self.memory else {
672            return Ok(());
673        };
674
675        for tool in outcome.tool_results() {
676            let payload = Bytes::from(serde_json::to_vec(&tool.output).map_err(|err| {
677                HandlerError::custom(format!("failed to encode tool output: {err}"))
678            })?);
679            let record = MemoryRecord::builder(MemoryChannel::Tool, payload)
680                .tag("mxp.call")
681                .map_err(|err| map_memory_error(&err))?
682                .tag("tool")
683                .map_err(|err| map_memory_error(&err))?
684                .metadata("direction", Value::from("tool"))
685                .metadata("tool_name", Value::from(tool.name.clone()))
686                .build()
687                .map_err(|err| map_memory_error(&err))?;
688            self.enforce_memory_policy(agent_id, &record).await?;
689            memory
690                .record(record)
691                .await
692                .map_err(|err| map_memory_error(&err))?;
693        }
694
695        let response_record = MemoryRecord::builder(
696            MemoryChannel::Output,
697            Bytes::from(outcome.response().to_owned()),
698        )
699        .tag("mxp.call")
700        .map_err(|err| map_memory_error(&err))?
701        .metadata("direction", Value::from("outbound"))
702        .metadata("message_type", Value::from("call"))
703        .build()
704        .map_err(|err| map_memory_error(&err))?;
705
706        self.enforce_memory_policy(agent_id, &response_record)
707            .await?;
708        memory
709            .record(response_record)
710            .await
711            .map_err(|err| map_memory_error(&err))?;
712        Ok(())
713    }
714
715    async fn enforce_memory_policy(
716        &self,
717        agent_id: AgentId,
718        record: &MemoryRecord,
719    ) -> HandlerResult<()> {
720        let Some(policy) = self.executor.policy() else {
721            return Ok(());
722        };
723
724        let request = PolicyRequest::from_memory_record(agent_id, record);
725        let decision = policy
726            .evaluate(&request)
727            .await
728            .map_err(|err| map_policy_error(&err))?;
729
730        self.executor
731            .notify_policy(&request, &decision, &request.action().label());
732        enforce_decision(&decision, &request.action().label())
733    }
734
735    /// Returns the underlying executor for advanced scenarios.
736    #[must_use]
737    pub fn executor(&self) -> &CallExecutor {
738        &self.executor
739    }
740}
741
742/// Builder for [`KernelMessageHandler`] that automates tool registration and optional components.
743pub struct KernelMessageHandlerBuilder {
744    adapter: Arc<dyn ModelAdapter>,
745    sink: Arc<dyn CallOutcomeSink>,
746    tools: Vec<ToolBinding>,
747    memory: Option<Arc<MemoryBus>>,
748    policy: Option<Arc<dyn PolicyEngine>>,
749    policy_observer: Option<Arc<dyn PolicyObserver>>,
750}
751
752impl KernelMessageHandlerBuilder {
753    fn new(adapter: Arc<dyn ModelAdapter>, sink: Arc<dyn CallOutcomeSink>) -> Self {
754        Self {
755            adapter,
756            sink,
757            tools: Vec::new(),
758            memory: None,
759            policy: None,
760            policy_observer: None,
761        }
762    }
763
764    /// Registers tool functions; descriptors are resolved automatically.
765    ///
766    /// # Errors
767    ///
768    /// Returns [`ToolError`] if any referenced function is missing a `#[tool]` annotation
769    /// or if the generated binding fails validation.
770    pub fn with_tools<F, I>(mut self, tools: I) -> Result<Self, ToolError>
771    where
772        F: Copy + 'static,
773        I: IntoIterator<Item = F>,
774    {
775        for tool in tools {
776            let type_name = std::any::type_name_of_val(&tool);
777            let descriptor = descriptor_from_type_name(type_name);
778            self.tools.push(descriptor.binding()?);
779        }
780        Ok(self)
781    }
782
783    /// Configures the memory bus used to persist call transcripts.
784    #[must_use]
785    pub fn with_memory(mut self, memory: Arc<MemoryBus>) -> Self {
786        self.memory = Some(memory);
787        self
788    }
789
790    /// Installs or replaces the policy engine.
791    #[must_use]
792    pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
793        self.policy = Some(policy);
794        self
795    }
796
797    /// Installs or replaces the policy observer.
798    #[must_use]
799    pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
800        self.policy_observer = Some(observer);
801        self
802    }
803
804    /// Finalises the builder, registering tools and returning a configured handler.
805    ///
806    /// # Errors
807    ///
808    /// Returns [`ToolError`] if tool registration fails (for example, duplicate names).
809    pub fn build(self) -> Result<KernelMessageHandler, ToolError> {
810        let registry = Arc::new(ToolRegistry::new());
811        for binding in self.tools {
812            registry.register_binding(binding)?;
813        }
814
815        let mut handler = KernelMessageHandler::new(self.adapter, registry, self.sink);
816
817        if let Some(memory) = self.memory {
818            handler.set_memory(memory);
819        }
820        if let Some(policy) = self.policy {
821            handler.set_policy(policy);
822        }
823        if let Some(observer) = self.policy_observer {
824            handler.set_policy_observer(observer);
825        }
826
827        Ok(handler)
828    }
829}
830
831#[async_trait]
832impl crate::AgentMessageHandler for KernelMessageHandler {
833    async fn handle_call(&self, ctx: HandlerContext) -> HandlerResult {
834        self.record_inbound(&ctx).await?;
835
836        let outcome = self.executor.execute(&ctx).await?;
837
838        self.record_outbound(ctx.agent_id(), &outcome).await?;
839
840        self.sink.record(outcome);
841        Ok(())
842    }
843}
844
845/// Observer trait used to capture call outcomes (for logging, metrics, etc.).
846pub trait CallOutcomeSink: Send + Sync {
847    /// Records the outcome of a call invocation.
848    fn record(&self, outcome: CallOutcome);
849}
850
851/// Sink implementation that logs to tracing.
852#[derive(Default)]
853pub struct TracingCallSink;
854
855impl CallOutcomeSink for TracingCallSink {
856    fn record(&self, outcome: CallOutcome) {
857        let tool_names: Vec<String> = outcome
858            .tool_results()
859            .iter()
860            .map(|result| result.name.clone())
861            .collect();
862        tracing::info!(
863            response = outcome.response(),
864            tools = ?tool_names,
865            "call execution completed"
866        );
867    }
868}
869
870/// Sink used during testing to capture outcomes.
871#[derive(Default)]
872pub struct CollectingSink {
873    results: Mutex<Vec<CallOutcome>>,
874}
875
876impl CollectingSink {
877    /// Creates a new collecting sink.
878    #[must_use]
879    pub fn new() -> Arc<Self> {
880        Arc::new(Self {
881            results: Mutex::new(Vec::new()),
882        })
883    }
884
885    /// Returns the collected outcomes.
886    ///
887    /// # Panics
888    ///
889    /// Panics if the internal mutex has been poisoned by a previous panic.
890    #[must_use]
891    pub fn drain(&self) -> Vec<CallOutcome> {
892        let mut lock = self.results.lock().expect("collecting sink poisoned");
893        lock.drain(..).collect()
894    }
895}
896
897impl CallOutcomeSink for CollectingSink {
898    fn record(&self, outcome: CallOutcome) {
899        self.results
900            .lock()
901            .expect("collecting sink poisoned")
902            .push(outcome);
903    }
904}
905
#[cfg(test)]
mod tests {
    use super::*;

    use agent_adapters::traits::{AdapterMetadata, AdapterResult, AdapterStream, InferenceChunk};
    use agent_memory::{FileJournal, MemoryBusBuilder, MemoryChannel, VolatileConfig};
    use agent_policy::{PolicyAction, PolicyDecision, PolicyEngine, PolicyRequest, PolicyResult};
    use agent_primitives::AgentId;
    use agent_tools::registry::{ToolMetadata, ToolRegistry};
    use futures::stream;
    use mxp::Message;
    use serde_json::json;
    use std::io::ErrorKind;
    use std::net::SocketAddr;
    use std::num::NonZeroUsize;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::oneshot;

    use crate::{AgentMessageHandler, HandlerContext, HandlerError};

    /// Model adapter that answers every inference request with a single chunk
    /// (flagged `true` in `InferenceChunk::new`) containing `response`.
    struct StaticAdapter {
        metadata: AdapterMetadata,
        response: String,
    }

    #[async_trait]
    impl ModelAdapter for StaticAdapter {
        fn metadata(&self) -> &AdapterMetadata {
            &self.metadata
        }

        async fn infer(&self, _request: InferenceRequest) -> AdapterResult<AdapterStream> {
            let chunk = InferenceChunk::new(self.response.clone(), true);
            Ok(Box::pin(stream::once(async move { Ok(chunk) })))
        }
    }

    /// Policy that denies every tool invocation and allows every other action.
    struct DenyPolicy;

    #[async_trait]
    impl PolicyEngine for DenyPolicy {
        async fn evaluate(&self, request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
            match request.action() {
                PolicyAction::InvokeTool { .. } => Ok(PolicyDecision::deny("disabled by policy")),
                _ => Ok(PolicyDecision::allow()),
            }
        }
    }

    /// Unique path in the system temp directory that is deleted when the guard
    /// is dropped. Because `Drop` also runs while a failing assertion unwinds,
    /// journal files are cleaned up even when a test fails.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new() -> Self {
            let mut path = std::env::temp_dir();
            path.push(format!("handler-test-{}.log", AgentId::random()));
            Self(path)
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a [`StaticAdapter`] that always streams back `response`.
    fn static_adapter(response: &str) -> Arc<StaticAdapter> {
        Arc::new(StaticAdapter {
            metadata: AdapterMetadata::new("test", "static"),
            response: response.to_owned(),
        })
    }

    /// Builds a registry holding a single `echo` tool that returns its input unchanged.
    fn echo_registry() -> Arc<ToolRegistry> {
        let tools = Arc::new(ToolRegistry::new());
        tools
            .register_tool(
                ToolMetadata::new("echo", "1.0.0").unwrap(),
                |input: Value| async move { Ok(input) },
            )
            .unwrap();
        tools
    }

    /// Serialises `payload` into an MXP `Call` message and wraps it in a handler
    /// context attributed to a fresh random agent.
    fn call_context(payload: &Value) -> HandlerContext {
        let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
        HandlerContext::from_message(AgentId::random(), message)
    }

    #[tokio::test]
    async fn executes_call_pipeline() {
        let sink = CollectingSink::new();
        let handler = KernelMessageHandler::new(
            static_adapter("static-response"),
            echo_registry(),
            sink.clone(),
        );

        let payload = json!({
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Ping"}
            ],
            "tools": [
                {"name": "echo", "input": {"value": 1}}
            ]
        });

        handler.handle_call(call_context(&payload)).await.unwrap();

        let results = sink.drain();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].response(), "static-response");
        assert_eq!(results[0].tool_results().len(), 1);
    }

    #[tokio::test]
    async fn policy_denies_tool_invocation() {
        let sink = CollectingSink::new();
        let handler = KernelMessageHandler::new(
            static_adapter("static-response"),
            echo_registry(),
            sink.clone(),
        )
        .with_policy(Arc::new(DenyPolicy));

        let payload = json!({
            "messages": [
                {"role": "user", "content": "ping"}
            ],
            "tools": [
                {"name": "echo", "input": {"value": 1}}
            ]
        });

        let err = handler
            .handle_call(call_context(&payload))
            .await
            .expect_err("policy should deny");
        match err {
            HandlerError::Custom(reason) => assert!(reason.contains("policy denied")),
            other => panic!("unexpected error: {other:?}"),
        }

        // A denied call must not produce a recorded outcome.
        assert!(sink.drain().is_empty());
    }

    #[tokio::test]
    async fn persists_transcript_via_memory_bus() {
        // Declared first so it is dropped last, after the journal and bus.
        let journal_path = TempPath::new();
        let journal: Arc<dyn agent_memory::Journal> =
            Arc::new(FileJournal::open(&journal_path.0).await.unwrap());
        let bus = Arc::new(
            MemoryBusBuilder::new(VolatileConfig::new(NonZeroUsize::new(8).unwrap()))
                .with_journal(journal)
                .build()
                .unwrap(),
        );

        let handler = KernelMessageHandler::new(
            static_adapter("ok"),
            echo_registry(),
            CollectingSink::new(),
        )
        .with_memory(bus.clone());

        let payload = json!({
            "messages": [
                {"role": "user", "content": "hello"}
            ],
            "tools": [
                {"name": "echo", "input": {"value": 1}}
            ]
        });

        handler.handle_call(call_context(&payload)).await.unwrap();

        // Expect one record each for the input, the tool call and the output.
        let records = bus.recent(5).await;
        assert_eq!(records.len(), 3);
        assert!(matches!(records[0].channel(), MemoryChannel::Input));
        assert!(matches!(records[1].channel(), MemoryChannel::Tool));
        assert!(matches!(records[2].channel(), MemoryChannel::Output));
    }

    /// Policy observer that stores `"<agent_id>:<subject>"` plus the decision
    /// kind for every decision it is notified about.
    struct RecordingObserver {
        decisions: Mutex<Vec<(String, DecisionKind)>>,
    }

    impl RecordingObserver {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                decisions: Mutex::new(Vec::new()),
            })
        }
    }

    impl PolicyObserver for RecordingObserver {
        fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
            let mut guard = self.decisions.lock().expect("observer poisoned");
            guard.push((
                format!("{}:{}", request.agent_id(), subject),
                decision.kind(),
            ));
        }
    }

    #[tokio::test]
    async fn observer_receives_decisions() {
        let observer = RecordingObserver::new();
        let handler = KernelMessageHandler::new(
            static_adapter("ok"),
            echo_registry(),
            CollectingSink::new(),
        )
        .with_policy(Arc::new(DenyPolicy))
        .with_policy_observer(observer.clone());

        let payload = json!({
            "messages": [
                {"role": "user", "content": "ping"}
            ],
            "tools": [
                {"name": "echo", "input": {"value": 1}}
            ]
        });

        handler
            .handle_call(call_context(&payload))
            .await
            .expect_err("policy should deny");

        let records = observer
            .decisions
            .lock()
            .expect("observer poisoned")
            .clone();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].1, DecisionKind::Deny);
    }

    /// Policy that denies `EmitEvent` actions of type `memory_record` and allows
    /// everything else.
    struct MemoryDenyPolicy;

    #[async_trait]
    impl PolicyEngine for MemoryDenyPolicy {
        async fn evaluate(&self, request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
            match request.action() {
                PolicyAction::EmitEvent { event_type } if event_type == "memory_record" => {
                    Ok(PolicyDecision::deny("memory recording disabled"))
                }
                _ => Ok(PolicyDecision::allow()),
            }
        }
    }

    #[tokio::test]
    async fn policy_denies_memory_recording() {
        // Declared first so it is dropped last, after the journal and bus.
        let journal_path = TempPath::new();
        let journal: Arc<dyn agent_memory::Journal> =
            Arc::new(FileJournal::open(&journal_path.0).await.expect("journal"));
        let memory_bus = Arc::new(
            MemoryBusBuilder::new(VolatileConfig::default())
                .with_journal(journal)
                .build()
                .expect("bus"),
        );

        let handler = KernelMessageHandler::new(
            static_adapter("ok"),
            echo_registry(),
            CollectingSink::new(),
        )
        .with_memory(memory_bus)
        .with_policy(Arc::new(MemoryDenyPolicy));

        let payload = json!({
            "messages": [
                {"role": "user", "content": "ping"}
            ],
            "tools": [
                {"name": "echo", "input": {"value": 1}}
            ]
        });

        let err = handler
            .handle_call(call_context(&payload))
            .await
            .expect_err("policy should deny");
        match err {
            HandlerError::Custom(reason) => assert!(reason.contains("policy denied")),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn composite_audit_emitter_fans_out() {
        let emitter_a = RecordingAuditEmitter::new();
        let emitter_b = RecordingAuditEmitter::new();

        let composite = CompositeAuditEmitter::new([
            Arc::clone(&emitter_a) as Arc<dyn AuditEmitter>,
            Arc::clone(&emitter_b) as Arc<dyn AuditEmitter>,
        ]);

        let message = Message::new(MessageType::Event, b"composite");
        composite.emit(message);

        let events_a = emitter_a.events.lock().expect("emitter A poisoned");
        assert_eq!(events_a.len(), 1);
        let events_b = emitter_b.events.lock().expect("emitter B poisoned");
        assert_eq!(events_b.len(), 1);
    }

    #[tokio::test]
    async fn governance_emitter_transmits_over_mxp() {
        let transport = mxp::Transport::default();
        let receiver = match transport
            .bind("127.0.0.1:0".parse::<SocketAddr>().expect("receiver addr"))
        {
            Ok(handle) => handle,
            Err(mxp::transport::SocketError::Io(err))
                if err.kind() == ErrorKind::PermissionDenied =>
            {
                eprintln!(
                    "skipping governance emitter test: receiver bind requires elevated privileges"
                );
                return;
            }
            Err(err) => panic!("bind receiver: {err:?}"),
        };
        let receiver_addr = receiver.local_addr().expect("receiver local addr");
        let sender = match transport.bind("127.0.0.1:0".parse::<SocketAddr>().expect("sender addr"))
        {
            Ok(handle) => handle,
            Err(mxp::transport::SocketError::Io(err))
                if err.kind() == ErrorKind::PermissionDenied =>
            {
                eprintln!(
                    "skipping governance emitter test: sender bind requires elevated privileges"
                );
                return;
            }
            Err(err) => panic!("bind sender: {err:?}"),
        };

        let emitter = GovernanceAuditEmitter::new(sender, receiver_addr);
        let message = Message::new(MessageType::Event, b"{\"audit\":true}");

        emitter.emit(message.clone());

        // NOTE(review): `receive` presumably blocks until a datagram arrives and is
        // given no timeout here. If the event is never delivered, the 500 ms timeout
        // below panics, but dropping the Tokio runtime waits for this blocking task,
        // so the test would hang instead of failing. Consider a socket read timeout
        // if the transport exposes one.
        let (tx, rx) = oneshot::channel();
        let recv_handle = receiver.clone();
        let join = tokio::task::spawn_blocking(move || {
            let mut buffer = recv_handle.acquire_buffer();
            let result = recv_handle.receive(&mut buffer);
            // The receiver side may already be gone after a timeout; ignore that.
            let _ = tx.send(result.map(|(len, _)| buffer.as_slice()[..len].to_vec()));
        });

        let recv_result = tokio::time::timeout(Duration::from_millis(500), rx)
            .await
            .expect("timed out waiting for audit event")
            .expect("audit receiver task cancelled");
        let bytes: Vec<u8> = match recv_result {
            Ok(buf) => buf,
            Err(err) => panic!("receiver failed to capture event: {err:?}"),
        };

        join.await.expect("blocking receive failed");

        let decoded = Message::decode(bytes).expect("decoded message");
        assert_eq!(decoded.message_type(), Some(MessageType::Event));
        assert_eq!(decoded.payload(), message.payload());
    }

    /// Audit emitter that stores every emitted message for later inspection.
    struct RecordingAuditEmitter {
        events: Mutex<Vec<Message>>,
    }

    impl RecordingAuditEmitter {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                events: Mutex::new(Vec::new()),
            })
        }
    }

    impl AuditEmitter for RecordingAuditEmitter {
        fn emit(&self, message: Message) {
            self.events.lock().expect("emitter poisoned").push(message);
        }
    }

    /// Policy that escalates every request to the `secops` approver.
    struct EscalatePolicy;

    #[async_trait]
    impl PolicyEngine for EscalatePolicy {
        async fn evaluate(&self, _request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
            Ok(PolicyDecision::escalate(
                "needs approval",
                vec!["secops".into()],
            ))
        }
    }

    #[tokio::test]
    async fn audit_observer_emits_event_on_escalation() {
        let emitter = RecordingAuditEmitter::new();
        let observer = CompositePolicyObserver::new([
            Arc::new(TracingPolicyObserver) as Arc<dyn PolicyObserver>,
            Arc::new(MxpAuditObserver::new(emitter.clone())) as Arc<dyn PolicyObserver>,
        ]);

        let handler = KernelMessageHandler::new(
            static_adapter("ok"),
            echo_registry(),
            CollectingSink::new(),
        )
        .with_policy(Arc::new(EscalatePolicy))
        .with_policy_observer(Arc::new(observer) as Arc<dyn PolicyObserver>);

        let payload = json!({
            "messages": [
                {"role": "user", "content": "ping"}
            ]
        });

        let err = handler
            .handle_call(call_context(&payload))
            .await
            .expect_err("policy should escalate");
        match err {
            HandlerError::Custom(reason) => assert!(reason.contains("policy escalation")),
            other => panic!("unexpected error: {other:?}"),
        }

        let events = emitter.events.lock().expect("emitter poisoned");
        assert_eq!(events.len(), 1);
        let payload = String::from_utf8_lossy(events[0].payload());
        assert!(payload.contains("needs approval"));
        assert!(payload.contains("secops"));
    }
}