// modeldriveprotocol_client/client.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, RwLock};
3
4use tokio::sync::{mpsc, Mutex};
5use tokio::task::JoinHandle;
6use url::Url;
7
8use crate::error::MdpClientError;
9use crate::models::{
10    AuthContext, ClientDescriptor, ClientInfo, ClientInfoOverride, EndpointOptions, HttpMethod,
11    PromptOptions, SkillOptions,
12};
13use crate::protocol::{CallClientRequest, ClientToServerMessage, ServerToClientMessage};
14use crate::registry::ProcedureRegistry;
15use crate::transport::{ClientTransport, HttpLoopClientTransport, WebSocketClientTransport};
16
/// Public handle to an MDP client.
///
/// All state lives in a shared [`MdpClientInner`] behind an [`Arc`], which
/// `connect` clones into the background task that reads server messages.
pub struct MdpClient {
    /// Shared state, also referenced by the spawned receive task.
    inner: Arc<MdpClientInner>,
}
20
/// State shared between the [`MdpClient`] handle and its receive task.
struct MdpClientInner {
    /// Client identity; rewritten by `register` when an override is supplied.
    client_info: RwLock<ClientInfo>,
    /// Optional auth context sent along with the `RegisterClient` message.
    auth: RwLock<Option<AuthContext>>,
    /// Locally exposed endpoints, skills and prompts; consulted for incoming calls.
    registry: RwLock<ProcedureRegistry>,
    /// Active transport. Guarded by a tokio mutex because the lock is held
    /// across `.await` points while connecting, sending and closing.
    transport: Mutex<Box<dyn ClientTransport>>,
    /// Handle of the task spawned by `connect` to run `process_messages`.
    receive_task: Mutex<Option<JoinHandle<()>>>,
    /// Set by `connect`; cleared by `disconnect` and when the receive loop ends.
    connected: AtomicBool,
    /// Set once `register` has sent `RegisterClient`; cleared alongside `connected`.
    registered: AtomicBool,
}
30
31impl MdpClient {
32    pub fn new(server_url: impl Into<String>, client: ClientInfo) -> Result<Self, MdpClientError> {
33        let server_url = server_url.into();
34        let url = Url::parse(&server_url).map_err(|error| MdpClientError::Transport(error.to_string()))?;
35        let transport: Box<dyn ClientTransport> = match url.scheme() {
36            "ws" | "wss" => Box::new(WebSocketClientTransport::new(server_url, None)),
37            "http" | "https" => Box::new(HttpLoopClientTransport::new(server_url, None)),
38            other => return Err(MdpClientError::Transport(format!("unsupported protocol `{other}`"))),
39        };
40
41        Ok(Self::with_transport(client, transport))
42    }
43
44    pub fn with_transport(client: ClientInfo, transport: Box<dyn ClientTransport>) -> Self {
45        Self {
46            inner: Arc::new(MdpClientInner {
47                client_info: RwLock::new(client),
48                auth: RwLock::new(None),
49                registry: RwLock::new(ProcedureRegistry::default()),
50                transport: Mutex::new(transport),
51                receive_task: Mutex::new(None),
52                connected: AtomicBool::new(false),
53                registered: AtomicBool::new(false),
54            }),
55        }
56    }
57
58    pub fn set_auth(&self, auth: Option<AuthContext>) {
59        *self.inner.auth.write().unwrap() = auth;
60    }
61
62    pub fn describe(&self) -> ClientDescriptor {
63        let client_info = self.inner.client_info.read().unwrap().clone();
64        self.inner.registry.read().unwrap().describe(&client_info)
65    }
66
67    pub fn expose_endpoint<H, Fut>(
68        &self,
69        path: impl Into<String>,
70        method: HttpMethod,
71        handler: H,
72        options: EndpointOptions,
73    ) -> Result<(), MdpClientError>
74    where
75        H: Send + Sync + 'static + Fn(crate::models::PathRequest, crate::models::PathInvocationContext) -> Fut,
76        Fut: std::future::Future<Output = Result<serde_json::Value, MdpClientError>> + Send + 'static,
77    {
78        self.inner
79            .registry
80            .write()
81            .unwrap()
82            .expose_endpoint(path, method, handler, options)
83    }
84
85    pub fn expose_skill_markdown(
86        &self,
87        path: impl Into<String>,
88        content: impl Into<String>,
89        options: SkillOptions,
90    ) -> Result<(), MdpClientError> {
91        self.inner
92            .registry
93            .write()
94            .unwrap()
95            .expose_skill_markdown(path, content, options)
96    }
97
98    pub fn expose_prompt_markdown(
99        &self,
100        path: impl Into<String>,
101        content: impl Into<String>,
102        options: PromptOptions,
103    ) -> Result<(), MdpClientError> {
104        self.inner
105            .registry
106            .write()
107            .unwrap()
108            .expose_prompt_markdown(path, content, options)
109    }
110
111    pub async fn connect(&self) -> Result<(), MdpClientError> {
112        let receiver = {
113            let mut transport = self.inner.transport.lock().await;
114            transport.connect().await?
115        };
116        self.inner.connected.store(true, Ordering::SeqCst);
117        let inner = self.inner.clone();
118        let task = tokio::spawn(async move {
119            process_messages(inner, receiver).await;
120        });
121        *self.inner.receive_task.lock().await = Some(task);
122        Ok(())
123    }
124
125    pub async fn register(
126        &self,
127        override_info: Option<ClientInfoOverride>,
128    ) -> Result<(), MdpClientError> {
129        if !self.inner.connected.load(Ordering::SeqCst) {
130            return Err(MdpClientError::NotConnected);
131        }
132
133        {
134            let current = self.inner.client_info.read().unwrap().clone();
135            *self.inner.client_info.write().unwrap() = current.apply_override(override_info);
136        }
137
138        let descriptor = self.describe();
139        let auth = self.inner.auth.read().unwrap().clone();
140        self.send(ClientToServerMessage::RegisterClient {
141            client: descriptor,
142            auth,
143        })
144        .await?;
145        self.inner.registered.store(true, Ordering::SeqCst);
146        Ok(())
147    }
148
149    pub async fn sync_catalog(&self) -> Result<(), MdpClientError> {
150        if !self.inner.connected.load(Ordering::SeqCst) {
151            return Err(MdpClientError::NotConnected);
152        }
153        if !self.inner.registered.load(Ordering::SeqCst) {
154            return Err(MdpClientError::NotRegistered);
155        }
156
157        let client_id = self.inner.client_info.read().unwrap().id.clone();
158        let paths = self.inner.registry.read().unwrap().describe_paths();
159        self.send(ClientToServerMessage::UpdateClientCatalog { client_id, paths })
160            .await
161    }
162
163    pub async fn disconnect(&self) -> Result<(), MdpClientError> {
164        if self.inner.connected.load(Ordering::SeqCst) && self.inner.registered.load(Ordering::SeqCst) {
165            let client_id = self.inner.client_info.read().unwrap().id.clone();
166            self.send(ClientToServerMessage::UnregisterClient { client_id }).await?;
167        }
168        self.inner.connected.store(false, Ordering::SeqCst);
169        self.inner.registered.store(false, Ordering::SeqCst);
170        {
171            let mut transport = self.inner.transport.lock().await;
172            transport.close().await?;
173        }
174        if let Some(task) = self.inner.receive_task.lock().await.take() {
175            task.abort();
176        }
177        Ok(())
178    }
179
180    async fn send(&self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
181        let mut transport = self.inner.transport.lock().await;
182        transport.send(message).await
183    }
184}
185
186async fn process_messages(inner: Arc<MdpClientInner>, mut receiver: mpsc::UnboundedReceiver<ServerToClientMessage>) {
187    while let Some(message) = receiver.recv().await {
188        match message {
189            ServerToClientMessage::Ping { timestamp } => {
190                let mut transport = inner.transport.lock().await;
191                let _ = transport.send(ClientToServerMessage::Pong { timestamp }).await;
192            }
193            ServerToClientMessage::Pong { .. } => {}
194            ServerToClientMessage::CallClient(message) => {
195                let result = handle_invocation(&inner, &message).await;
196                let mut transport = inner.transport.lock().await;
197                let _ = transport.send(result).await;
198            }
199        }
200    }
201
202    inner.connected.store(false, Ordering::SeqCst);
203    inner.registered.store(false, Ordering::SeqCst);
204}
205
206async fn handle_invocation(
207    inner: &Arc<MdpClientInner>,
208    message: &CallClientRequest,
209) -> ClientToServerMessage {
210    let registry = inner.registry.read().unwrap().clone();
211    match registry.invoke(message).await {
212        Ok(data) => ClientToServerMessage::CallClientResult {
213            request_id: message.request_id.clone(),
214            ok: true,
215            data: Some(data),
216            error: None,
217        },
218        Err(error) => ClientToServerMessage::CallClientResult {
219            request_id: message.request_id.clone(),
220            ok: false,
221            data: None,
222            error: Some(crate::models::SerializedError::handler(error.to_string())),
223        },
224    }
225}