use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use synwire_core::BoxFuture;
use synwire_core::agents::error::AgentError;
use synwire_core::mcp::elicitation::{ElicitationRequest, ElicitationResult, OnElicitation};
/// Severity attached to a log message emitted by an MCP server.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate that matches on it must include
/// a wildcard arm.
///
/// NOTE(review): the serde derives carry no `rename_all`, so variants (de)serialize under their
/// Rust names (`"Debug"`, `"Info"`, `"Warning"`, `"Error"`). The MCP specification's logging
/// levels are lowercase and also include `notice`, `critical`, `alert` and `emergency` — confirm
/// that conversion from the protocol representation happens elsewhere before relying on these
/// impls for wire traffic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum McpLogLevel {
    /// Debug-level diagnostic message.
    Debug,
    /// Informational message.
    Info,
    /// Warning message.
    Warning,
    /// Error message.
    Error,
}
/// A single log message from an MCP server, as delivered to an [`OnMcpLogging`] handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpLoggingMessage {
    /// Severity of the message.
    pub level: McpLogLevel,
    /// Name of the logger that produced the message, if the server supplied one.
    pub logger: Option<String>,
    /// Message payload as arbitrary JSON.
    pub data: Value,
}
/// Handler for log messages forwarded from MCP servers.
///
/// Implementors must be `Send + Sync`. The method is synchronous and returns `()`, so a handler
/// has no way to report failure back to its caller.
/// NOTE(review): presumably invoked on the client's notification path — keep implementations
/// cheap and non-blocking; confirm against the call site.
pub trait OnMcpLogging: Send + Sync {
    /// Handles one log `message` originating from the server named `server_name`.
    fn on_log(&self, server_name: &str, message: McpLoggingMessage);
}
/// A progress update reported by an MCP server, as delivered to an [`OnMcpProgress`] handler.
///
/// NOTE(review): serde uses the Rust field names verbatim (e.g. `progress_token`), and the
/// token / counters are narrowed to `String` / `u64`. The MCP specification uses camelCase
/// (`progressToken`), allows the token to be a string or an integer, and allows non-integral
/// progress values — confirm conversion from the wire format happens elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpProgressNotification {
    /// Token identifying which operation this update refers to.
    pub progress_token: String,
    /// Progress made so far.
    pub progress: u64,
    /// Expected total, when the server provides one.
    pub total: Option<u64>,
    /// Optional human-readable status text.
    pub message: Option<String>,
}
/// Handler for progress notifications forwarded from MCP servers.
///
/// Implementors must be `Send + Sync`. The method is synchronous and returns `()`, so a handler
/// has no way to report failure back to its caller.
pub trait OnMcpProgress: Send + Sync {
    /// Handles one progress `notification` originating from the server named `server_name`.
    fn on_progress(&self, server_name: &str, notification: McpProgressNotification);
}
/// An [`OnMcpLogging`] handler that drops every message without recording anything.
///
/// This is the logging handler installed by `McpCallbacks::default()`.
#[derive(Debug, Default, Clone)]
pub struct DiscardLogging;

impl OnMcpLogging for DiscardLogging {
    /// Ignores the server name and the message alike.
    fn on_log(&self, _: &str, _: McpLoggingMessage) {}
}
/// An [`OnMcpProgress`] handler that drops every progress notification.
///
/// This is the progress handler installed by `McpCallbacks::default()`.
#[derive(Debug, Default, Clone)]
pub struct DiscardProgress;

impl OnMcpProgress for DiscardProgress {
    /// Ignores the server name and the notification alike.
    fn on_progress(&self, _: &str, _: McpProgressNotification) {}
}
#[derive(Debug, Default, Clone)]
pub struct TracingLogging;
impl OnMcpLogging for TracingLogging {
fn on_log(&self, server_name: &str, message: McpLoggingMessage) {
match message.level {
McpLogLevel::Debug => {
tracing::debug!(server = %server_name, logger = ?message.logger, data = ?message.data, "MCP log");
}
McpLogLevel::Info => {
tracing::info!(server = %server_name, logger = ?message.logger, data = ?message.data, "MCP log");
}
McpLogLevel::Warning => {
tracing::warn!(server = %server_name, logger = ?message.logger, data = ?message.data, "MCP log");
}
McpLogLevel::Error => {
tracing::error!(server = %server_name, logger = ?message.logger, data = ?message.data, "MCP log");
}
}
}
}
/// The set of client-side handlers used for MCP server logging, progress and elicitation.
///
/// Every handler is stored behind an `Arc`, so cloning a `McpCallbacks` only bumps reference
/// counts. The clone shares the same handler instances, which lets one callback set be reused
/// across several places.
///
/// [`McpCallbacks::default`] installs no-op logging and progress handlers and an elicitation
/// handler that cancels every request. Individual handlers can then be overridden with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct McpCallbacks {
    /// Handler for server log messages.
    pub logging: Arc<dyn OnMcpLogging>,
    /// Handler for server progress notifications.
    pub progress: Arc<dyn OnMcpProgress>,
    /// Handler for server elicitation requests.
    pub elicitation: Arc<dyn OnElicitation>,
}
impl std::fmt::Debug for McpCallbacks {
    /// Formats the callback set with a fixed placeholder in place of each handler.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The handlers are trait objects that are not required to implement `Debug`, so each
        // field is shown as the same placeholder string.
        const PLACEHOLDER: &str = "<handler>";
        let mut builder = f.debug_struct("McpCallbacks");
        for field_name in ["logging", "progress", "elicitation"] {
            builder.field(field_name, &PLACEHOLDER);
        }
        builder.finish()
    }
}
impl Default for McpCallbacks {
fn default() -> Self {
Self {
logging: Arc::new(DiscardLogging),
progress: Arc::new(DiscardProgress),
elicitation: Arc::new(CancelAllElicitationsAdapter),
}
}
}
impl McpCallbacks {
    /// Creates a callback set with the default handlers; identical to `McpCallbacks::default()`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this callback set with its logging handler swapped for `logging`.
    #[must_use]
    pub fn with_logging(self, logging: Arc<dyn OnMcpLogging>) -> Self {
        Self { logging, ..self }
    }

    /// Returns this callback set with its progress handler swapped for `progress`.
    #[must_use]
    pub fn with_progress(self, progress: Arc<dyn OnMcpProgress>) -> Self {
        Self { progress, ..self }
    }

    /// Returns this callback set with its elicitation handler swapped for `elicitation`.
    #[must_use]
    pub fn with_elicitation(self, elicitation: Arc<dyn OnElicitation>) -> Self {
        Self { elicitation, ..self }
    }
}
#[derive(Debug)]
struct CancelAllElicitationsAdapter;
impl OnElicitation for CancelAllElicitationsAdapter {
fn elicit(
&self,
request: ElicitationRequest,
) -> BoxFuture<'_, Result<ElicitationResult, AgentError>> {
Box::pin(async move {
Ok(ElicitationResult::Cancelled {
request_id: request.request_id,
})
})
}
}