//! modeldriveprotocol-client 2.2.0
//!
//! Rust client SDK for Model Drive Protocol.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};

use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use url::Url;

use crate::error::MdpClientError;
use crate::models::{
    AuthContext, ClientDescriptor, ClientInfo, ClientInfoOverride, EndpointOptions, HttpMethod,
    PromptOptions, SkillOptions,
};
use crate::protocol::{CallClientRequest, ClientToServerMessage, ServerToClientMessage};
use crate::registry::ProcedureRegistry;
use crate::transport::{ClientTransport, HttpLoopClientTransport, WebSocketClientTransport};

/// Public handle to a Model Drive Protocol client.
///
/// All state lives behind a shared [`Arc`] so the background receive task
/// spawned by `connect` can reach the same transport, registry and status flags.
pub struct MdpClient {
    // Shared state; a clone of this `Arc` is moved into the receive task.
    inner: Arc<MdpClientInner>,
}

/// State shared between the [`MdpClient`] handle and its background receive task.
struct MdpClientInner {
    // Client identity used to build the descriptor; rewritten by `register`
    // when an override is applied.
    client_info: RwLock<ClientInfo>,
    // Optional auth context sent along with the `RegisterClient` message.
    auth: RwLock<Option<AuthContext>>,
    // Exposed endpoints/skills/prompts; read for descriptors and cloned for invocations.
    registry: RwLock<ProcedureRegistry>,
    // Active transport. Uses the tokio mutex because the guard is held across `.await`.
    transport: Mutex<Box<dyn ClientTransport>>,
    // Handle of the task running `process_messages`, if one has been spawned.
    receive_task: Mutex<Option<JoinHandle<()>>>,
    // Set by `connect`; cleared by `disconnect` and when the receive loop ends.
    connected: AtomicBool,
    // Set after a `RegisterClient` message is sent; cleared together with `connected`.
    registered: AtomicBool,
}

impl MdpClient {
    pub fn new(server_url: impl Into<String>, client: ClientInfo) -> Result<Self, MdpClientError> {
        let server_url = server_url.into();
        let url = Url::parse(&server_url).map_err(|error| MdpClientError::Transport(error.to_string()))?;
        let transport: Box<dyn ClientTransport> = match url.scheme() {
            "ws" | "wss" => Box::new(WebSocketClientTransport::new(server_url, None)),
            "http" | "https" => Box::new(HttpLoopClientTransport::new(server_url, None)),
            other => return Err(MdpClientError::Transport(format!("unsupported protocol `{other}`"))),
        };

        Ok(Self::with_transport(client, transport))
    }

    pub fn with_transport(client: ClientInfo, transport: Box<dyn ClientTransport>) -> Self {
        Self {
            inner: Arc::new(MdpClientInner {
                client_info: RwLock::new(client),
                auth: RwLock::new(None),
                registry: RwLock::new(ProcedureRegistry::default()),
                transport: Mutex::new(transport),
                receive_task: Mutex::new(None),
                connected: AtomicBool::new(false),
                registered: AtomicBool::new(false),
            }),
        }
    }

    pub fn set_auth(&self, auth: Option<AuthContext>) {
        *self.inner.auth.write().unwrap() = auth;
    }

    pub fn describe(&self) -> ClientDescriptor {
        let client_info = self.inner.client_info.read().unwrap().clone();
        self.inner.registry.read().unwrap().describe(&client_info)
    }

    pub fn expose_endpoint<H, Fut>(
        &self,
        path: impl Into<String>,
        method: HttpMethod,
        handler: H,
        options: EndpointOptions,
    ) -> Result<(), MdpClientError>
    where
        H: Send + Sync + 'static + Fn(crate::models::PathRequest, crate::models::PathInvocationContext) -> Fut,
        Fut: std::future::Future<Output = Result<serde_json::Value, MdpClientError>> + Send + 'static,
    {
        self.inner
            .registry
            .write()
            .unwrap()
            .expose_endpoint(path, method, handler, options)
    }

    pub fn expose_skill_markdown(
        &self,
        path: impl Into<String>,
        content: impl Into<String>,
        options: SkillOptions,
    ) -> Result<(), MdpClientError> {
        self.inner
            .registry
            .write()
            .unwrap()
            .expose_skill_markdown(path, content, options)
    }

    pub fn expose_prompt_markdown(
        &self,
        path: impl Into<String>,
        content: impl Into<String>,
        options: PromptOptions,
    ) -> Result<(), MdpClientError> {
        self.inner
            .registry
            .write()
            .unwrap()
            .expose_prompt_markdown(path, content, options)
    }

    pub async fn connect(&self) -> Result<(), MdpClientError> {
        let receiver = {
            let mut transport = self.inner.transport.lock().await;
            transport.connect().await?
        };
        self.inner.connected.store(true, Ordering::SeqCst);
        let inner = self.inner.clone();
        let task = tokio::spawn(async move {
            process_messages(inner, receiver).await;
        });
        *self.inner.receive_task.lock().await = Some(task);
        Ok(())
    }

    pub async fn register(
        &self,
        override_info: Option<ClientInfoOverride>,
    ) -> Result<(), MdpClientError> {
        if !self.inner.connected.load(Ordering::SeqCst) {
            return Err(MdpClientError::NotConnected);
        }

        {
            let current = self.inner.client_info.read().unwrap().clone();
            *self.inner.client_info.write().unwrap() = current.apply_override(override_info);
        }

        let descriptor = self.describe();
        let auth = self.inner.auth.read().unwrap().clone();
        self.send(ClientToServerMessage::RegisterClient {
            client: descriptor,
            auth,
        })
        .await?;
        self.inner.registered.store(true, Ordering::SeqCst);
        Ok(())
    }

    pub async fn sync_catalog(&self) -> Result<(), MdpClientError> {
        if !self.inner.connected.load(Ordering::SeqCst) {
            return Err(MdpClientError::NotConnected);
        }
        if !self.inner.registered.load(Ordering::SeqCst) {
            return Err(MdpClientError::NotRegistered);
        }

        let client_id = self.inner.client_info.read().unwrap().id.clone();
        let paths = self.inner.registry.read().unwrap().describe_paths();
        self.send(ClientToServerMessage::UpdateClientCatalog { client_id, paths })
            .await
    }

    pub async fn disconnect(&self) -> Result<(), MdpClientError> {
        if self.inner.connected.load(Ordering::SeqCst) && self.inner.registered.load(Ordering::SeqCst) {
            let client_id = self.inner.client_info.read().unwrap().id.clone();
            self.send(ClientToServerMessage::UnregisterClient { client_id }).await?;
        }
        self.inner.connected.store(false, Ordering::SeqCst);
        self.inner.registered.store(false, Ordering::SeqCst);
        {
            let mut transport = self.inner.transport.lock().await;
            transport.close().await?;
        }
        if let Some(task) = self.inner.receive_task.lock().await.take() {
            task.abort();
        }
        Ok(())
    }

    async fn send(&self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
        let mut transport = self.inner.transport.lock().await;
        transport.send(message).await
    }
}

/// Receive loop run on a background task: answers pings, dispatches client
/// invocations, and ignores pongs until the channel closes.
///
/// Replies are sent best-effort; send failures are deliberately discarded.
/// When `receiver` yields `None`, both the `connected` and `registered` flags
/// are cleared.
async fn process_messages(
    inner: Arc<MdpClientInner>,
    mut receiver: mpsc::UnboundedReceiver<ServerToClientMessage>,
) {
    loop {
        let Some(incoming) = receiver.recv().await else {
            break;
        };

        // Work out the reply first so the transport lock is only taken for the send.
        let reply = match incoming {
            ServerToClientMessage::Ping { timestamp } => {
                Some(ClientToServerMessage::Pong { timestamp })
            }
            ServerToClientMessage::Pong { .. } => None,
            ServerToClientMessage::CallClient(request) => {
                Some(handle_invocation(&inner, &request).await)
            }
        };

        if let Some(outgoing) = reply {
            let mut guard = inner.transport.lock().await;
            let _ = guard.send(outgoing).await;
        }
    }

    inner.connected.store(false, Ordering::SeqCst);
    inner.registered.store(false, Ordering::SeqCst);
}

/// Runs the procedure requested by `message` and wraps the outcome in a
/// `CallClientResult` carrying the same `request_id`.
///
/// On success the result has `ok: true` and the returned data; on failure it
/// has `ok: false` and the error's string form wrapped via
/// `SerializedError::handler`.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
async fn handle_invocation(
    inner: &Arc<MdpClientInner>,
    message: &CallClientRequest,
) -> ClientToServerMessage {
    // Clone the registry so the std lock guard is released before awaiting.
    let snapshot = inner.registry.read().unwrap().clone();
    let outcome = snapshot.invoke(message).await;

    let (ok, data, error) = match outcome {
        Ok(value) => (true, Some(value), None),
        Err(failure) => (
            false,
            None,
            Some(crate::models::SerializedError::handler(failure.to_string())),
        ),
    };

    ClientToServerMessage::CallClientResult {
        request_id: message.request_id.clone(),
        ok,
        data,
        error,
    }
}