1use std::fmt;
4use std::net::SocketAddr;
5use std::sync::{Arc, Mutex};
6
7use agent_adapters::traits::{
8 AdapterError, InferenceRequest, MessageRole, ModelAdapter, PromptMessage,
9};
10use agent_memory::{MemoryBus, MemoryChannel, MemoryError, MemoryRecord};
11use agent_policy::{
12 DecisionKind, PolicyAction, PolicyDecision, PolicyEngine, PolicyError, PolicyRequest,
13};
14use agent_primitives::AgentId;
15use agent_tools::registry::{ToolBinding, ToolError, ToolRegistry, descriptor_from_type_name};
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures::StreamExt;
19use mxp::{Message, MessageType, TransportHandle};
20use serde::Deserialize;
21use serde_json::{Value, json};
22use tokio::task;
23use tracing::{debug, info, warn};
24
25use crate::{HandlerContext, HandlerError, HandlerResult};
26
27pub trait AuditEmitter: Send + Sync {
29 fn emit(&self, message: Message);
31}
32
33pub struct CompositeAuditEmitter {
35 emitters: Vec<Arc<dyn AuditEmitter>>,
36}
37
38impl CompositeAuditEmitter {
39 #[must_use]
41 pub fn new<I>(emitters: I) -> Self
42 where
43 I: IntoIterator<Item = Arc<dyn AuditEmitter>>,
44 {
45 Self {
46 emitters: emitters.into_iter().collect(),
47 }
48 }
49
50 pub fn push(&mut self, emitter: Arc<dyn AuditEmitter>) {
52 self.emitters.push(emitter);
53 }
54}
55
56impl AuditEmitter for CompositeAuditEmitter {
57 fn emit(&self, message: Message) {
58 for emitter in &self.emitters {
59 emitter.emit(message.clone());
60 }
61 }
62}
63
64#[derive(Default)]
66pub struct TracingAuditEmitter;
67
68impl AuditEmitter for TracingAuditEmitter {
69 fn emit(&self, message: Message) {
70 let payload = String::from_utf8_lossy(message.payload());
71 info!(
72 event = ?message.message_type(),
73 payload = %payload,
74 "policy audit event emitted"
75 );
76 }
77}
78
79#[derive(Clone)]
81pub struct GovernanceAuditEmitter {
82 transport: TransportHandle,
83 target: SocketAddr,
84}
85
86impl GovernanceAuditEmitter {
87 #[must_use]
89 pub fn new(transport: TransportHandle, target: SocketAddr) -> Self {
90 Self { transport, target }
91 }
92}
93
94impl AuditEmitter for GovernanceAuditEmitter {
95 fn emit(&self, message: Message) {
96 let transport = self.transport.clone();
97 let target = self.target;
98 let msg_type = message.message_type();
99 let message_id = message.message_id();
100 let trace_id = message.trace_id();
101 let encoded = message.encode();
102
103 task::spawn(async move {
104 if let Err(err) = transport.send(&encoded, target) {
105 warn!(
106 ?err,
107 %target,
108 message_id,
109 trace_id,
110 ?msg_type,
111 "failed to deliver governance audit event",
112 );
113 }
114 });
115 }
116}
117
118pub trait PolicyObserver: Send + Sync {
120 fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str);
122}
123
124#[derive(Default)]
126pub struct TracingPolicyObserver;
127
128impl PolicyObserver for TracingPolicyObserver {
129 fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
130 let reason = decision.reason().unwrap_or_default();
131 let approvers = decision.required_approvals();
132 match decision.kind() {
133 DecisionKind::Allow => {
134 debug!(agent_id = %request.agent_id(), subject, "policy allow");
135 }
136 DecisionKind::Deny => {
137 warn!(
138 agent_id = %request.agent_id(),
139 subject,
140 reason,
141 "policy deny"
142 );
143 }
144 DecisionKind::Escalate => {
145 warn!(
146 agent_id = %request.agent_id(),
147 subject,
148 reason,
149 approvers = ?approvers,
150 "policy escalate"
151 );
152 }
153 }
154 }
155}
156
157pub struct CompositePolicyObserver {
159 observers: Vec<Arc<dyn PolicyObserver>>,
160}
161
162impl CompositePolicyObserver {
163 #[must_use]
165 pub fn new<I>(observers: I) -> Self
166 where
167 I: IntoIterator<Item = Arc<dyn PolicyObserver>>,
168 {
169 Self {
170 observers: observers.into_iter().collect(),
171 }
172 }
173
174 pub fn push(&mut self, observer: Arc<dyn PolicyObserver>) {
176 self.observers.push(observer);
177 }
178}
179
180impl PolicyObserver for CompositePolicyObserver {
181 fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
182 for observer in &self.observers {
183 observer.on_decision(request, decision, subject);
184 }
185 }
186}
187
188pub struct MxpAuditObserver {
190 emitter: Arc<dyn AuditEmitter>,
191}
192
193impl MxpAuditObserver {
194 #[must_use]
196 pub fn new(emitter: Arc<dyn AuditEmitter>) -> Self {
197 Self { emitter }
198 }
199}
200
201impl PolicyObserver for MxpAuditObserver {
202 fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
203 if matches!(decision.kind(), DecisionKind::Deny | DecisionKind::Escalate) {
204 let payload = json!({
205 "agent_id": request.agent_id().to_string(),
206 "subject": subject,
207 "decision": format!("{:?}", decision.kind()),
208 "reason": decision.reason(),
209 "approvers": decision.required_approvals(),
210 "metadata": request.context().metadata(),
211 });
212 let payload_string = payload.to_string();
213 let message = Message::new(MessageType::Event, payload_string.as_bytes());
214 self.emitter.emit(message);
215 }
216 }
217}
218
219#[derive(Clone)]
222pub struct CallExecutor {
223 adapter: Arc<dyn ModelAdapter>,
224 tools: Arc<ToolRegistry>,
225 policy: Option<Arc<dyn PolicyEngine>>,
226 policy_observer: Option<Arc<dyn PolicyObserver>>,
227}
228
229impl fmt::Debug for CallExecutor {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 let metadata = self.adapter.metadata();
232 f.debug_struct("CallExecutor")
233 .field("provider", &metadata.provider())
234 .field("model", &metadata.model())
235 .field("policy_configured", &self.policy.is_some())
236 .field("observer_configured", &self.policy_observer.is_some())
237 .finish_non_exhaustive()
238 }
239}
240
241impl CallExecutor {
242 #[must_use]
244 pub fn new(adapter: Arc<dyn ModelAdapter>, tools: Arc<ToolRegistry>) -> Self {
245 Self {
246 adapter,
247 tools,
248 policy: None,
249 policy_observer: None,
250 }
251 }
252
253 pub fn set_policy(&mut self, policy: Arc<dyn PolicyEngine>) {
255 self.policy = Some(policy);
256 }
257
258 #[must_use]
260 pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
261 self.set_policy(policy);
262 self
263 }
264
265 #[must_use]
267 pub fn policy(&self) -> Option<&Arc<dyn PolicyEngine>> {
268 self.policy.as_ref()
269 }
270
271 pub fn set_policy_observer(&mut self, observer: Arc<dyn PolicyObserver>) {
273 self.policy_observer = Some(observer);
274 }
275
276 #[must_use]
278 pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
279 self.set_policy_observer(observer);
280 self
281 }
282
283 #[must_use]
285 pub fn policy_observer(&self) -> Option<&Arc<dyn PolicyObserver>> {
286 self.policy_observer.as_ref()
287 }
288
289 fn notify_policy(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
290 if let Some(observer) = &self.policy_observer {
291 observer.on_decision(request, decision, subject);
292 }
293 }
294
295 async fn enforce_tool_policy(
296 &self,
297 ctx: &HandlerContext,
298 invocation: &ToolInvocation,
299 ) -> HandlerResult<()> {
300 let Some(policy) = self.policy.as_ref() else {
301 return Ok(());
302 };
303
304 let mut request = PolicyRequest::new(
305 ctx.agent_id(),
306 PolicyAction::InvokeTool {
307 name: invocation.name.clone(),
308 },
309 );
310
311 request
312 .context_mut()
313 .insert_metadata("input", invocation.input.clone());
314
315 if let Some(handle) = self.tools.get(&invocation.name) {
316 let metadata = handle.metadata().clone();
317 request
318 .context_mut()
319 .insert_metadata("tool_version", Value::from(metadata.version().to_owned()));
320 if let Some(description) = metadata.description() {
321 request
322 .context_mut()
323 .insert_metadata("tool_description", Value::from(description.to_owned()));
324 }
325
326 if !metadata.capabilities().is_empty() {
327 let capabilities: Vec<String> = metadata
328 .capabilities()
329 .iter()
330 .map(|cap| cap.as_str().to_owned())
331 .collect();
332 request
333 .context_mut()
334 .insert_metadata("capabilities", Value::from(capabilities.clone()));
335 request
336 .context_mut()
337 .extend_tags(capabilities.iter().map(|cap| format!("cap:{cap}")));
338 }
339 }
340
341 let decision = policy
342 .evaluate(&request)
343 .await
344 .map_err(|err| map_policy_error(&err))?;
345
346 self.notify_policy(&request, &decision, &request.action().label());
347 enforce_decision(&decision, &request.action().label())
348 }
349
350 async fn enforce_inference_policy(
351 &self,
352 ctx: &HandlerContext,
353 message_count: usize,
354 tool_names: &[String],
355 ) -> HandlerResult<()> {
356 let Some(policy) = self.policy.as_ref() else {
357 return Ok(());
358 };
359
360 let metadata = self.adapter.metadata();
361 let mut request = PolicyRequest::new(
362 ctx.agent_id(),
363 PolicyAction::ModelInference {
364 provider: metadata.provider().to_owned(),
365 model: metadata.model().to_owned(),
366 },
367 );
368
369 request
370 .context_mut()
371 .insert_metadata("message_count", Value::from(message_count as u64));
372
373 if !tool_names.is_empty() {
374 request
375 .context_mut()
376 .insert_metadata("tools", Value::from(tool_names.to_owned()));
377 request
378 .context_mut()
379 .extend_tags(tool_names.iter().map(|name| format!("tool:{name}")));
380 }
381
382 let decision = policy
383 .evaluate(&request)
384 .await
385 .map_err(|err| map_policy_error(&err))?;
386
387 self.notify_policy(&request, &decision, &request.action().label());
388 enforce_decision(&decision, &request.action().label())
389 }
390
391 pub async fn execute(&self, ctx: &HandlerContext) -> HandlerResult<CallOutcome> {
398 let payload = parse_payload(ctx)?;
399
400 let mut messages = payload.messages;
401 let mut tool_names = Vec::new();
402 let mut tool_results = Vec::new();
403
404 for invocation in payload.tools {
405 self.enforce_tool_policy(ctx, &invocation).await?;
406
407 let tool_output = self
408 .tools
409 .invoke(&invocation.name, invocation.input.clone())
410 .await
411 .map_err(|err| map_tool_error(&invocation.name, &err))?;
412
413 let message_content =
414 serde_json::to_string(&tool_output).unwrap_or_else(|_| String::new());
415 messages.push(PromptMessage::new(MessageRole::Tool, message_content));
416 tool_names.push(invocation.name.clone());
417 tool_results.push(ToolInvocationResult {
418 name: invocation.name,
419 output: tool_output,
420 });
421 }
422
423 self.enforce_inference_policy(ctx, messages.len(), &tool_names)
424 .await?;
425
426 let mut request = InferenceRequest::new(messages)
427 .map_err(|err| HandlerError::custom(format!("invalid request: {err}")))?;
428
429 if let Some(max_tokens) = payload.max_output_tokens {
430 request = request.with_max_output_tokens(max_tokens);
431 }
432
433 if let Some(temperature) = payload.temperature {
434 request = request.with_temperature(temperature);
435 }
436
437 if !tool_names.is_empty() {
438 request = request.with_tools(tool_names);
439 }
440
441 let mut stream = self
442 .adapter
443 .infer(request)
444 .await
445 .map_err(|err| map_adapter_error(&err, self.adapter.metadata()))?;
446
447 let mut response = String::new();
448 while let Some(chunk) = stream.next().await {
449 let chunk = chunk.map_err(|err| map_adapter_error(&err, self.adapter.metadata()))?;
450 response.push_str(&chunk.delta);
451 if chunk.done {
452 break;
453 }
454 }
455
456 Ok(CallOutcome {
457 response,
458 tool_results,
459 })
460 }
461}
462
463fn parse_payload(ctx: &HandlerContext) -> HandlerResult<CallPayload> {
464 let payload = ctx.message().payload();
465 if payload.is_empty() {
466 return Err(HandlerError::custom("call payload missing"));
467 }
468
469 serde_json::from_slice::<CallPayload>(payload.as_ref())
470 .map_err(|err| HandlerError::custom(format!("failed to decode call payload: {err}")))
471}
472
473fn map_tool_error(name: &str, err: &ToolError) -> HandlerError {
474 HandlerError::custom(format!("tool `{name}` failed: {err}"))
475}
476
477fn map_adapter_error(
478 err: &AdapterError,
479 metadata: &agent_adapters::traits::AdapterMetadata,
480) -> HandlerError {
481 HandlerError::custom(format!(
482 "adapter `{}` for model `{}` error: {err}",
483 metadata.provider(),
484 metadata.model()
485 ))
486}
487
488fn map_memory_error(err: &MemoryError) -> HandlerError {
489 HandlerError::custom(format!("memory error: {err}"))
490}
491
492fn map_policy_error(err: &PolicyError) -> HandlerError {
493 HandlerError::custom(format!("policy engine error: {err}"))
494}
495
496fn enforce_decision(decision: &PolicyDecision, subject: &str) -> HandlerResult<()> {
497 match decision.kind() {
498 DecisionKind::Allow => Ok(()),
499 DecisionKind::Deny => {
500 let reason = decision.reason().unwrap_or("policy denied the request");
501 Err(HandlerError::custom(format!(
502 "policy denied {subject}: {reason}"
503 )))
504 }
505 DecisionKind::Escalate => {
506 let reason = decision.reason().unwrap_or("policy escalation required");
507 let approvers = decision.required_approvals();
508 let detail = if approvers.is_empty() {
509 reason.to_owned()
510 } else {
511 format!("{reason} (approvers: {})", approvers.join(", "))
512 };
513 Err(HandlerError::custom(format!(
514 "policy escalation required for {subject}: {detail}"
515 )))
516 }
517 }
518}
519
520#[derive(Debug)]
522pub struct CallOutcome {
523 response: String,
524 tool_results: Vec<ToolInvocationResult>,
525}
526
527impl CallOutcome {
528 #[must_use]
530 pub fn response(&self) -> &str {
531 &self.response
532 }
533
534 #[must_use]
536 pub fn tool_results(&self) -> &[ToolInvocationResult] {
537 &self.tool_results
538 }
539}
540
541#[derive(Debug, Clone)]
543pub struct ToolInvocationResult {
544 pub name: String,
546 pub output: Value,
548}
549
550#[derive(Debug, Deserialize)]
551struct CallPayload {
552 messages: Vec<PromptMessage>,
553 #[serde(default)]
554 temperature: Option<f32>,
555 #[serde(default)]
556 max_output_tokens: Option<u32>,
557 #[serde(default)]
558 tools: Vec<ToolInvocation>,
559}
560
561#[derive(Debug, Deserialize)]
562struct ToolInvocation {
563 name: String,
564 #[serde(default)]
565 input: Value,
566}
567
568pub struct KernelMessageHandler {
570 executor: Arc<CallExecutor>,
571 sink: Arc<dyn CallOutcomeSink>,
572 memory: Option<Arc<MemoryBus>>,
573}
574
575impl KernelMessageHandler {
576 #[must_use]
578 pub fn new(
579 adapter: Arc<dyn ModelAdapter>,
580 tools: Arc<ToolRegistry>,
581 sink: Arc<dyn CallOutcomeSink>,
582 ) -> Self {
583 let executor = Arc::new(CallExecutor::new(adapter, tools));
584 Self {
585 executor,
586 sink,
587 memory: None,
588 }
589 }
590
591 #[must_use]
593 pub fn builder(
594 adapter: Arc<dyn ModelAdapter>,
595 sink: Arc<dyn CallOutcomeSink>,
596 ) -> KernelMessageHandlerBuilder {
597 KernelMessageHandlerBuilder::new(adapter, sink)
598 }
599
600 #[must_use]
602 pub fn with_memory(mut self, memory: Arc<MemoryBus>) -> Self {
603 self.memory = Some(memory);
604 self
605 }
606
607 pub fn set_memory(&mut self, memory: Arc<MemoryBus>) {
609 self.memory = Some(memory);
610 }
611
612 #[must_use]
614 pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
615 self.set_policy(policy);
616 self
617 }
618
619 pub fn set_policy(&mut self, policy: Arc<dyn PolicyEngine>) {
621 Arc::make_mut(&mut self.executor).set_policy(policy);
622 }
623
624 #[must_use]
626 pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
627 self.set_policy_observer(observer);
628 self
629 }
630
631 pub fn set_policy_observer(&mut self, observer: Arc<dyn PolicyObserver>) {
633 Arc::make_mut(&mut self.executor).set_policy_observer(observer);
634 }
635
636 #[must_use]
638 pub fn policy_observer(&self) -> Option<&Arc<dyn PolicyObserver>> {
639 self.executor.policy_observer()
640 }
641
642 #[must_use]
644 pub fn memory(&self) -> Option<&Arc<MemoryBus>> {
645 self.memory.as_ref()
646 }
647
648 async fn record_inbound(&self, ctx: &HandlerContext) -> HandlerResult<()> {
649 let Some(memory) = &self.memory else {
650 return Ok(());
651 };
652
653 let record = MemoryRecord::builder(MemoryChannel::Input, ctx.message().payload().clone())
654 .tag("mxp.call")
655 .map_err(|err| map_memory_error(&err))?
656 .metadata("direction", Value::from("inbound"))
657 .metadata("message_type", Value::from("call"))
658 .metadata("agent_id", Value::from(ctx.agent_id().to_string()))
659 .build()
660 .map_err(|err| map_memory_error(&err))?;
661
662 self.enforce_memory_policy(ctx.agent_id(), &record).await?;
663 memory
664 .record(record)
665 .await
666 .map_err(|err| map_memory_error(&err))?;
667 Ok(())
668 }
669
670 async fn record_outbound(&self, agent_id: AgentId, outcome: &CallOutcome) -> HandlerResult<()> {
671 let Some(memory) = &self.memory else {
672 return Ok(());
673 };
674
675 for tool in outcome.tool_results() {
676 let payload = Bytes::from(serde_json::to_vec(&tool.output).map_err(|err| {
677 HandlerError::custom(format!("failed to encode tool output: {err}"))
678 })?);
679 let record = MemoryRecord::builder(MemoryChannel::Tool, payload)
680 .tag("mxp.call")
681 .map_err(|err| map_memory_error(&err))?
682 .tag("tool")
683 .map_err(|err| map_memory_error(&err))?
684 .metadata("direction", Value::from("tool"))
685 .metadata("tool_name", Value::from(tool.name.clone()))
686 .build()
687 .map_err(|err| map_memory_error(&err))?;
688 self.enforce_memory_policy(agent_id, &record).await?;
689 memory
690 .record(record)
691 .await
692 .map_err(|err| map_memory_error(&err))?;
693 }
694
695 let response_record = MemoryRecord::builder(
696 MemoryChannel::Output,
697 Bytes::from(outcome.response().to_owned()),
698 )
699 .tag("mxp.call")
700 .map_err(|err| map_memory_error(&err))?
701 .metadata("direction", Value::from("outbound"))
702 .metadata("message_type", Value::from("call"))
703 .build()
704 .map_err(|err| map_memory_error(&err))?;
705
706 self.enforce_memory_policy(agent_id, &response_record)
707 .await?;
708 memory
709 .record(response_record)
710 .await
711 .map_err(|err| map_memory_error(&err))?;
712 Ok(())
713 }
714
715 async fn enforce_memory_policy(
716 &self,
717 agent_id: AgentId,
718 record: &MemoryRecord,
719 ) -> HandlerResult<()> {
720 let Some(policy) = self.executor.policy() else {
721 return Ok(());
722 };
723
724 let request = PolicyRequest::from_memory_record(agent_id, record);
725 let decision = policy
726 .evaluate(&request)
727 .await
728 .map_err(|err| map_policy_error(&err))?;
729
730 self.executor
731 .notify_policy(&request, &decision, &request.action().label());
732 enforce_decision(&decision, &request.action().label())
733 }
734
735 #[must_use]
737 pub fn executor(&self) -> &CallExecutor {
738 &self.executor
739 }
740}
741
742pub struct KernelMessageHandlerBuilder {
744 adapter: Arc<dyn ModelAdapter>,
745 sink: Arc<dyn CallOutcomeSink>,
746 tools: Vec<ToolBinding>,
747 memory: Option<Arc<MemoryBus>>,
748 policy: Option<Arc<dyn PolicyEngine>>,
749 policy_observer: Option<Arc<dyn PolicyObserver>>,
750}
751
752impl KernelMessageHandlerBuilder {
753 fn new(adapter: Arc<dyn ModelAdapter>, sink: Arc<dyn CallOutcomeSink>) -> Self {
754 Self {
755 adapter,
756 sink,
757 tools: Vec::new(),
758 memory: None,
759 policy: None,
760 policy_observer: None,
761 }
762 }
763
764 pub fn with_tools<F, I>(mut self, tools: I) -> Result<Self, ToolError>
771 where
772 F: Copy + 'static,
773 I: IntoIterator<Item = F>,
774 {
775 for tool in tools {
776 let type_name = std::any::type_name_of_val(&tool);
777 let descriptor = descriptor_from_type_name(type_name);
778 self.tools.push(descriptor.binding()?);
779 }
780 Ok(self)
781 }
782
783 #[must_use]
785 pub fn with_memory(mut self, memory: Arc<MemoryBus>) -> Self {
786 self.memory = Some(memory);
787 self
788 }
789
790 #[must_use]
792 pub fn with_policy(mut self, policy: Arc<dyn PolicyEngine>) -> Self {
793 self.policy = Some(policy);
794 self
795 }
796
797 #[must_use]
799 pub fn with_policy_observer(mut self, observer: Arc<dyn PolicyObserver>) -> Self {
800 self.policy_observer = Some(observer);
801 self
802 }
803
804 pub fn build(self) -> Result<KernelMessageHandler, ToolError> {
810 let registry = Arc::new(ToolRegistry::new());
811 for binding in self.tools {
812 registry.register_binding(binding)?;
813 }
814
815 let mut handler = KernelMessageHandler::new(self.adapter, registry, self.sink);
816
817 if let Some(memory) = self.memory {
818 handler.set_memory(memory);
819 }
820 if let Some(policy) = self.policy {
821 handler.set_policy(policy);
822 }
823 if let Some(observer) = self.policy_observer {
824 handler.set_policy_observer(observer);
825 }
826
827 Ok(handler)
828 }
829}
830
831#[async_trait]
832impl crate::AgentMessageHandler for KernelMessageHandler {
833 async fn handle_call(&self, ctx: HandlerContext) -> HandlerResult {
834 self.record_inbound(&ctx).await?;
835
836 let outcome = self.executor.execute(&ctx).await?;
837
838 self.record_outbound(ctx.agent_id(), &outcome).await?;
839
840 self.sink.record(outcome);
841 Ok(())
842 }
843}
844
845pub trait CallOutcomeSink: Send + Sync {
847 fn record(&self, outcome: CallOutcome);
849}
850
851#[derive(Default)]
853pub struct TracingCallSink;
854
855impl CallOutcomeSink for TracingCallSink {
856 fn record(&self, outcome: CallOutcome) {
857 let tool_names: Vec<String> = outcome
858 .tool_results()
859 .iter()
860 .map(|result| result.name.clone())
861 .collect();
862 tracing::info!(
863 response = outcome.response(),
864 tools = ?tool_names,
865 "call execution completed"
866 );
867 }
868}
869
870#[derive(Default)]
872pub struct CollectingSink {
873 results: Mutex<Vec<CallOutcome>>,
874}
875
876impl CollectingSink {
877 #[must_use]
879 pub fn new() -> Arc<Self> {
880 Arc::new(Self {
881 results: Mutex::new(Vec::new()),
882 })
883 }
884
885 #[must_use]
891 pub fn drain(&self) -> Vec<CallOutcome> {
892 let mut lock = self.results.lock().expect("collecting sink poisoned");
893 lock.drain(..).collect()
894 }
895}
896
897impl CallOutcomeSink for CollectingSink {
898 fn record(&self, outcome: CallOutcome) {
899 self.results
900 .lock()
901 .expect("collecting sink poisoned")
902 .push(outcome);
903 }
904}
905
906#[cfg(test)]
907mod tests {
908 use super::*;
909
910 use agent_adapters::traits::{AdapterMetadata, AdapterResult, AdapterStream, InferenceChunk};
911 use agent_memory::{FileJournal, MemoryBusBuilder, MemoryChannel, VolatileConfig};
912 use agent_policy::{PolicyAction, PolicyDecision, PolicyEngine, PolicyRequest, PolicyResult};
913 use agent_primitives::AgentId;
914 use agent_tools::registry::{ToolMetadata, ToolRegistry};
915 use futures::stream;
916 use mxp::Message;
917 use serde_json::json;
918 use std::io::ErrorKind;
919 use std::net::SocketAddr;
920 use std::num::NonZeroUsize;
921 use std::sync::{Arc, Mutex};
922 use std::time::Duration;
923 use tokio::sync::oneshot;
924
925 use crate::{AgentMessageHandler, HandlerContext, HandlerError};
926
927 struct StaticAdapter {
928 metadata: AdapterMetadata,
929 response: String,
930 }
931
932 #[async_trait]
933 impl ModelAdapter for StaticAdapter {
934 fn metadata(&self) -> &AdapterMetadata {
935 &self.metadata
936 }
937
938 async fn infer(&self, _request: InferenceRequest) -> AdapterResult<AdapterStream> {
939 let chunk = InferenceChunk::new(self.response.clone(), true);
940 Ok(Box::pin(stream::once(async move { Ok(chunk) })))
941 }
942 }
943
944 struct DenyPolicy;
945
946 #[async_trait]
947 impl PolicyEngine for DenyPolicy {
948 async fn evaluate(&self, request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
949 match request.action() {
950 PolicyAction::InvokeTool { .. } => Ok(PolicyDecision::deny("disabled by policy")),
951 _ => Ok(PolicyDecision::allow()),
952 }
953 }
954 }
955
956 fn temp_path() -> std::path::PathBuf {
957 let mut path = std::env::temp_dir();
958 path.push(format!("handler-test-{}.log", AgentId::random()));
959 path
960 }
961
962 #[tokio::test]
963 async fn executes_call_pipeline() {
964 let adapter = Arc::new(StaticAdapter {
965 metadata: AdapterMetadata::new("test", "static"),
966 response: "static-response".to_owned(),
967 });
968 let tools = Arc::new(ToolRegistry::new());
969 tools
970 .register_tool(
971 ToolMetadata::new("echo", "1.0.0").unwrap(),
972 |input: Value| async move { Ok(input) },
973 )
974 .unwrap();
975
976 let sink = CollectingSink::new();
977 let handler = KernelMessageHandler::new(adapter, tools, sink.clone());
978
979 let payload = json!({
980 "messages": [
981 {"role": "system", "content": "You are helpful."},
982 {"role": "user", "content": "Ping"}
983 ],
984 "tools": [
985 {"name": "echo", "input": {"value": 1}}
986 ]
987 });
988
989 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
990 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
991
992 handler.handle_call(ctx).await.unwrap();
993
994 let results = sink.drain();
995 assert_eq!(results.len(), 1);
996 assert_eq!(results[0].response(), "static-response");
997 assert_eq!(results[0].tool_results().len(), 1);
998 }
999
1000 #[tokio::test]
1001 async fn policy_denies_tool_invocation() {
1002 let adapter = Arc::new(StaticAdapter {
1003 metadata: AdapterMetadata::new("test", "static"),
1004 response: "static-response".to_owned(),
1005 });
1006 let tools = Arc::new(ToolRegistry::new());
1007 tools
1008 .register_tool(
1009 ToolMetadata::new("echo", "1.0.0").unwrap(),
1010 |input: Value| async move { Ok(input) },
1011 )
1012 .unwrap();
1013
1014 let sink = CollectingSink::new();
1015 let handler = KernelMessageHandler::new(adapter, tools, sink.clone())
1016 .with_policy(Arc::new(DenyPolicy));
1017
1018 let payload = json!({
1019 "messages": [
1020 {"role": "user", "content": "ping"}
1021 ],
1022 "tools": [
1023 {"name": "echo", "input": {"value": 1}}
1024 ]
1025 });
1026
1027 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
1028 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
1029
1030 let err = handler
1031 .handle_call(ctx)
1032 .await
1033 .expect_err("policy should deny");
1034 match err {
1035 HandlerError::Custom(reason) => assert!(reason.contains("policy denied")),
1036 other => panic!("unexpected error: {other:?}"),
1037 }
1038
1039 assert!(sink.drain().is_empty());
1040 }
1041
1042 #[tokio::test]
1043 async fn persists_transcript_via_memory_bus() {
1044 let adapter = Arc::new(StaticAdapter {
1045 metadata: AdapterMetadata::new("test", "static"),
1046 response: "ok".to_owned(),
1047 });
1048
1049 let tools = Arc::new(ToolRegistry::new());
1050 tools
1051 .register_tool(
1052 ToolMetadata::new("echo", "1.0.0").unwrap(),
1053 |input: Value| async move { Ok(input) },
1054 )
1055 .unwrap();
1056
1057 let sink = CollectingSink::new();
1058 let path = temp_path();
1059 let journal: Arc<dyn agent_memory::Journal> =
1060 Arc::new(FileJournal::open(&path).await.unwrap());
1061 let bus = Arc::new(
1062 MemoryBusBuilder::new(VolatileConfig::new(NonZeroUsize::new(8).unwrap()))
1063 .with_journal(journal.clone())
1064 .build()
1065 .unwrap(),
1066 );
1067
1068 let handler =
1069 KernelMessageHandler::new(adapter, tools, sink.clone()).with_memory(bus.clone());
1070
1071 let payload = json!({
1072 "messages": [
1073 {"role": "user", "content": "hello"}
1074 ],
1075 "tools": [
1076 {"name": "echo", "input": {"value": 1}}
1077 ]
1078 });
1079
1080 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
1081 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
1082
1083 handler.handle_call(ctx).await.unwrap();
1084
1085 let records = bus.recent(5).await;
1086 assert_eq!(records.len(), 3);
1087 assert!(matches!(records[0].channel(), MemoryChannel::Input));
1088 assert!(matches!(records[1].channel(), MemoryChannel::Tool));
1089 assert!(matches!(records[2].channel(), MemoryChannel::Output));
1090
1091 if path.exists() {
1092 let _ = std::fs::remove_file(path);
1093 }
1094 }
1095
1096 struct RecordingObserver {
1097 decisions: Mutex<Vec<(String, DecisionKind)>>,
1098 }
1099
1100 impl RecordingObserver {
1101 fn new() -> Arc<Self> {
1102 Arc::new(Self {
1103 decisions: Mutex::new(Vec::new()),
1104 })
1105 }
1106 }
1107
1108 impl PolicyObserver for RecordingObserver {
1109 fn on_decision(&self, request: &PolicyRequest, decision: &PolicyDecision, subject: &str) {
1110 let mut guard = self.decisions.lock().expect("observer poisoned");
1111 guard.push((
1112 format!("{}:{}", request.agent_id(), subject),
1113 decision.kind(),
1114 ));
1115 }
1116 }
1117
1118 #[tokio::test]
1119 async fn observer_receives_decisions() {
1120 let adapter = Arc::new(StaticAdapter {
1121 metadata: AdapterMetadata::new("test", "static"),
1122 response: "ok".to_owned(),
1123 });
1124 let tools = Arc::new(ToolRegistry::new());
1125 tools
1126 .register_tool(
1127 ToolMetadata::new("echo", "1.0.0").unwrap(),
1128 |input: Value| async move { Ok(input) },
1129 )
1130 .unwrap();
1131
1132 let sink = CollectingSink::new();
1133 let observer = RecordingObserver::new();
1134 let handler = KernelMessageHandler::new(adapter, tools, sink.clone())
1135 .with_policy(Arc::new(DenyPolicy))
1136 .with_policy_observer(observer.clone());
1137
1138 let payload = json!({
1139 "messages": [
1140 {"role": "user", "content": "ping"}
1141 ],
1142 "tools": [
1143 {"name": "echo", "input": {"value": 1}}
1144 ]
1145 });
1146
1147 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
1148 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
1149
1150 handler
1151 .handle_call(ctx)
1152 .await
1153 .expect_err("policy should deny");
1154
1155 let records = observer
1156 .decisions
1157 .lock()
1158 .expect("observer poisoned")
1159 .clone();
1160 assert_eq!(records.len(), 1);
1161 assert_eq!(records[0].1, DecisionKind::Deny);
1162 }
1163
1164 struct MemoryDenyPolicy;
1165
1166 #[async_trait]
1167 impl PolicyEngine for MemoryDenyPolicy {
1168 async fn evaluate(&self, request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
1169 match request.action() {
1170 PolicyAction::EmitEvent { event_type } if event_type == "memory_record" => {
1171 Ok(PolicyDecision::deny("memory recording disabled"))
1172 }
1173 _ => Ok(PolicyDecision::allow()),
1174 }
1175 }
1176 }
1177
1178 #[tokio::test]
1179 async fn policy_denies_memory_recording() {
1180 let adapter = Arc::new(StaticAdapter {
1181 metadata: AdapterMetadata::new("test", "static"),
1182 response: "ok".to_owned(),
1183 });
1184 let tools = Arc::new(ToolRegistry::new());
1185 tools
1186 .register_tool(
1187 ToolMetadata::new("echo", "1.0.0").unwrap(),
1188 |input: Value| async move { Ok(input) },
1189 )
1190 .unwrap();
1191
1192 let sink = CollectingSink::new();
1193 let journal_path = temp_path();
1194 let journal: Arc<dyn agent_memory::Journal> =
1195 Arc::new(FileJournal::open(&journal_path).await.expect("journal"));
1196 let memory_bus = Arc::new(
1197 MemoryBusBuilder::new(VolatileConfig::default())
1198 .with_journal(journal)
1199 .build()
1200 .expect("bus"),
1201 );
1202
1203 let handler = KernelMessageHandler::new(adapter, tools, sink)
1204 .with_memory(memory_bus)
1205 .with_policy(Arc::new(MemoryDenyPolicy));
1206
1207 let payload = json!({
1208 "messages": [
1209 {"role": "user", "content": "ping"}
1210 ],
1211 "tools": [
1212 {"name": "echo", "input": {"value": 1}}
1213 ]
1214 });
1215
1216 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
1217 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
1218
1219 let err = handler
1220 .handle_call(ctx)
1221 .await
1222 .expect_err("policy should deny");
1223 match err {
1224 HandlerError::Custom(reason) => assert!(reason.contains("policy denied")),
1225 other => panic!("unexpected error: {other:?}"),
1226 }
1227
1228 if journal_path.exists() {
1229 let _ = std::fs::remove_file(&journal_path);
1230 }
1231 }
1232
1233 #[test]
1234 fn composite_audit_emitter_fans_out() {
1235 let emitter_a = RecordingAuditEmitter::new();
1236 let emitter_b = RecordingAuditEmitter::new();
1237
1238 let composite = CompositeAuditEmitter::new([
1239 Arc::clone(&emitter_a) as Arc<dyn AuditEmitter>,
1240 Arc::clone(&emitter_b) as Arc<dyn AuditEmitter>,
1241 ]);
1242
1243 let message = Message::new(MessageType::Event, b"composite");
1244 composite.emit(message);
1245
1246 let events_a = emitter_a.events.lock().expect("emitter A poisoned");
1247 assert_eq!(events_a.len(), 1);
1248 let events_b = emitter_b.events.lock().expect("emitter B poisoned");
1249 assert_eq!(events_b.len(), 1);
1250 }
1251
1252 #[tokio::test]
1253 async fn governance_emitter_transmits_over_mxp() {
1254 let transport = mxp::Transport::default();
1255 let receiver = match transport
1256 .bind("127.0.0.1:0".parse::<SocketAddr>().expect("receiver addr"))
1257 {
1258 Ok(handle) => handle,
1259 Err(mxp::transport::SocketError::Io(err))
1260 if err.kind() == ErrorKind::PermissionDenied =>
1261 {
1262 eprintln!(
1263 "skipping governance emitter test: receiver bind requires elevated privileges"
1264 );
1265 return;
1266 }
1267 Err(err) => panic!("bind receiver: {err:?}"),
1268 };
1269 let receiver_addr = receiver.local_addr().expect("receiver local addr");
1270 let sender = match transport.bind("127.0.0.1:0".parse::<SocketAddr>().expect("sender addr"))
1271 {
1272 Ok(handle) => handle,
1273 Err(mxp::transport::SocketError::Io(err))
1274 if err.kind() == ErrorKind::PermissionDenied =>
1275 {
1276 eprintln!(
1277 "skipping governance emitter test: sender bind requires elevated privileges"
1278 );
1279 return;
1280 }
1281 Err(err) => panic!("bind sender: {err:?}"),
1282 };
1283
1284 let emitter = GovernanceAuditEmitter::new(sender, receiver_addr);
1285 let message = Message::new(MessageType::Event, b"{\"audit\":true}");
1286
1287 emitter.emit(message.clone());
1288
1289 let (tx, rx) = oneshot::channel();
1290 let recv_handle = receiver.clone();
1291 let join = tokio::task::spawn_blocking(move || {
1292 let mut buffer = recv_handle.acquire_buffer();
1293 let result = recv_handle.receive(&mut buffer);
1294 tx.send(result.map(|(len, _)| buffer.as_slice()[..len].to_vec()))
1295 .ok();
1296 });
1297
1298 let recv_result = tokio::time::timeout(Duration::from_millis(500), rx)
1299 .await
1300 .expect("timed out waiting for audit event")
1301 .expect("audit receiver task cancelled");
1302 let bytes: Vec<u8> = match recv_result {
1303 Ok(buf) => buf,
1304 Err(err) => panic!("receiver failed to capture event: {err:?}"),
1305 };
1306
1307 join.await.expect("blocking receive failed");
1308
1309 let decoded = Message::decode(bytes).expect("decoded message");
1310 assert_eq!(decoded.message_type(), Some(MessageType::Event));
1311 assert_eq!(decoded.payload(), message.payload());
1312 }
1313
1314 struct RecordingAuditEmitter {
1315 events: Mutex<Vec<Message>>,
1316 }
1317
1318 impl RecordingAuditEmitter {
1319 fn new() -> Arc<Self> {
1320 Arc::new(Self {
1321 events: Mutex::new(Vec::new()),
1322 })
1323 }
1324 }
1325
1326 impl AuditEmitter for RecordingAuditEmitter {
1327 fn emit(&self, message: Message) {
1328 self.events.lock().expect("emitter poisoned").push(message);
1329 }
1330 }
1331
1332 struct EscalatePolicy;
1333
1334 #[async_trait]
1335 impl PolicyEngine for EscalatePolicy {
1336 async fn evaluate(&self, _request: &PolicyRequest) -> PolicyResult<PolicyDecision> {
1337 Ok(PolicyDecision::escalate(
1338 "needs approval",
1339 vec!["secops".into()],
1340 ))
1341 }
1342 }
1343
1344 #[tokio::test]
1345 async fn audit_observer_emits_event_on_escalation() {
1346 let adapter = Arc::new(StaticAdapter {
1347 metadata: AdapterMetadata::new("test", "static"),
1348 response: "ok".to_owned(),
1349 });
1350 let tools = Arc::new(ToolRegistry::new());
1351 tools
1352 .register_tool(
1353 ToolMetadata::new("echo", "1.0.0").unwrap(),
1354 |input: Value| async move { Ok(input) },
1355 )
1356 .unwrap();
1357
1358 let sink = CollectingSink::new();
1359 let emitter = RecordingAuditEmitter::new();
1360 let observer = CompositePolicyObserver::new([
1361 Arc::new(TracingPolicyObserver) as Arc<dyn PolicyObserver>,
1362 Arc::new(MxpAuditObserver::new(emitter.clone())) as Arc<dyn PolicyObserver>,
1363 ]);
1364
1365 let handler = KernelMessageHandler::new(adapter, tools, sink)
1366 .with_policy(Arc::new(EscalatePolicy))
1367 .with_policy_observer(Arc::new(observer) as Arc<dyn PolicyObserver>);
1368
1369 let payload = json!({
1370 "messages": [
1371 {"role": "user", "content": "ping"}
1372 ]
1373 });
1374
1375 let message = mxp::Message::new(mxp::MessageType::Call, payload.to_string().as_bytes());
1376 let ctx = HandlerContext::from_message(agent_primitives::AgentId::random(), message);
1377
1378 let err = handler
1379 .handle_call(ctx)
1380 .await
1381 .expect_err("policy should escalate");
1382 match err {
1383 HandlerError::Custom(reason) => assert!(reason.contains("policy escalation")),
1384 other => panic!("unexpected error: {other:?}"),
1385 }
1386
1387 let events = emitter.events.lock().expect("emitter poisoned");
1388 assert_eq!(events.len(), 1);
1389 let payload = String::from_utf8_lossy(events[0].payload());
1390 assert!(payload.contains("needs approval"));
1391 assert!(payload.contains("secops"));
1392 }
1393}