modeldriveprotocol_client/
client.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, RwLock};
3
4use tokio::sync::{mpsc, Mutex};
5use tokio::task::JoinHandle;
6use url::Url;
7
8use crate::error::MdpClientError;
9use crate::models::{
10 AuthContext, ClientDescriptor, ClientInfo, ClientInfoOverride, EndpointOptions, HttpMethod,
11 PromptOptions, SkillOptions,
12};
13use crate::protocol::{CallClientRequest, ClientToServerMessage, ServerToClientMessage};
14use crate::registry::ProcedureRegistry;
15use crate::transport::{ClientTransport, HttpLoopClientTransport, WebSocketClientTransport};
16
/// Handle to an MDP client.
///
/// Lets callers expose endpoints, skills, and prompts locally, register them with a server
/// over a transport, and serve the server's `CallClient` requests. All state lives in a
/// shared `MdpClientInner`. The background receive task spawned by `connect` holds its own
/// `Arc` clone of that state.
pub struct MdpClient {
    inner: Arc<MdpClientInner>,
}
20
/// State shared between an `MdpClient` and its background receive task.
struct MdpClientInner {
    /// Client identity; rewritten by `register` via `ClientInfo::apply_override`.
    client_info: RwLock<ClientInfo>,
    /// Optional auth context included in the `RegisterClient` message.
    auth: RwLock<Option<AuthContext>>,
    /// Locally exposed endpoints, skills, and prompts; invoked for `CallClient` requests.
    registry: RwLock<ProcedureRegistry>,
    /// Transport used for connect, send, and close. Guarded by a `tokio::sync::Mutex`
    /// because the guard is held across `.await` on transport calls.
    transport: Mutex<Box<dyn ClientTransport>>,
    /// Task draining server messages; `None` before `connect` and after `disconnect`.
    receive_task: Mutex<Option<JoinHandle<()>>>,
    /// Set by `connect`; cleared by `disconnect` or when the server message stream ends.
    connected: AtomicBool,
    /// Set after `register` successfully sends; cleared alongside `connected`.
    registered: AtomicBool,
}
30
31impl MdpClient {
32 pub fn new(server_url: impl Into<String>, client: ClientInfo) -> Result<Self, MdpClientError> {
33 let server_url = server_url.into();
34 let url = Url::parse(&server_url).map_err(|error| MdpClientError::Transport(error.to_string()))?;
35 let transport: Box<dyn ClientTransport> = match url.scheme() {
36 "ws" | "wss" => Box::new(WebSocketClientTransport::new(server_url, None)),
37 "http" | "https" => Box::new(HttpLoopClientTransport::new(server_url, None)),
38 other => return Err(MdpClientError::Transport(format!("unsupported protocol `{other}`"))),
39 };
40
41 Ok(Self::with_transport(client, transport))
42 }
43
44 pub fn with_transport(client: ClientInfo, transport: Box<dyn ClientTransport>) -> Self {
45 Self {
46 inner: Arc::new(MdpClientInner {
47 client_info: RwLock::new(client),
48 auth: RwLock::new(None),
49 registry: RwLock::new(ProcedureRegistry::default()),
50 transport: Mutex::new(transport),
51 receive_task: Mutex::new(None),
52 connected: AtomicBool::new(false),
53 registered: AtomicBool::new(false),
54 }),
55 }
56 }
57
58 pub fn set_auth(&self, auth: Option<AuthContext>) {
59 *self.inner.auth.write().unwrap() = auth;
60 }
61
62 pub fn describe(&self) -> ClientDescriptor {
63 let client_info = self.inner.client_info.read().unwrap().clone();
64 self.inner.registry.read().unwrap().describe(&client_info)
65 }
66
67 pub fn expose_endpoint<H, Fut>(
68 &self,
69 path: impl Into<String>,
70 method: HttpMethod,
71 handler: H,
72 options: EndpointOptions,
73 ) -> Result<(), MdpClientError>
74 where
75 H: Send + Sync + 'static + Fn(crate::models::PathRequest, crate::models::PathInvocationContext) -> Fut,
76 Fut: std::future::Future<Output = Result<serde_json::Value, MdpClientError>> + Send + 'static,
77 {
78 self.inner
79 .registry
80 .write()
81 .unwrap()
82 .expose_endpoint(path, method, handler, options)
83 }
84
85 pub fn expose_skill_markdown(
86 &self,
87 path: impl Into<String>,
88 content: impl Into<String>,
89 options: SkillOptions,
90 ) -> Result<(), MdpClientError> {
91 self.inner
92 .registry
93 .write()
94 .unwrap()
95 .expose_skill_markdown(path, content, options)
96 }
97
98 pub fn expose_prompt_markdown(
99 &self,
100 path: impl Into<String>,
101 content: impl Into<String>,
102 options: PromptOptions,
103 ) -> Result<(), MdpClientError> {
104 self.inner
105 .registry
106 .write()
107 .unwrap()
108 .expose_prompt_markdown(path, content, options)
109 }
110
111 pub async fn connect(&self) -> Result<(), MdpClientError> {
112 let receiver = {
113 let mut transport = self.inner.transport.lock().await;
114 transport.connect().await?
115 };
116 self.inner.connected.store(true, Ordering::SeqCst);
117 let inner = self.inner.clone();
118 let task = tokio::spawn(async move {
119 process_messages(inner, receiver).await;
120 });
121 *self.inner.receive_task.lock().await = Some(task);
122 Ok(())
123 }
124
125 pub async fn register(
126 &self,
127 override_info: Option<ClientInfoOverride>,
128 ) -> Result<(), MdpClientError> {
129 if !self.inner.connected.load(Ordering::SeqCst) {
130 return Err(MdpClientError::NotConnected);
131 }
132
133 {
134 let current = self.inner.client_info.read().unwrap().clone();
135 *self.inner.client_info.write().unwrap() = current.apply_override(override_info);
136 }
137
138 let descriptor = self.describe();
139 let auth = self.inner.auth.read().unwrap().clone();
140 self.send(ClientToServerMessage::RegisterClient {
141 client: descriptor,
142 auth,
143 })
144 .await?;
145 self.inner.registered.store(true, Ordering::SeqCst);
146 Ok(())
147 }
148
149 pub async fn sync_catalog(&self) -> Result<(), MdpClientError> {
150 if !self.inner.connected.load(Ordering::SeqCst) {
151 return Err(MdpClientError::NotConnected);
152 }
153 if !self.inner.registered.load(Ordering::SeqCst) {
154 return Err(MdpClientError::NotRegistered);
155 }
156
157 let client_id = self.inner.client_info.read().unwrap().id.clone();
158 let paths = self.inner.registry.read().unwrap().describe_paths();
159 self.send(ClientToServerMessage::UpdateClientCatalog { client_id, paths })
160 .await
161 }
162
163 pub async fn disconnect(&self) -> Result<(), MdpClientError> {
164 if self.inner.connected.load(Ordering::SeqCst) && self.inner.registered.load(Ordering::SeqCst) {
165 let client_id = self.inner.client_info.read().unwrap().id.clone();
166 self.send(ClientToServerMessage::UnregisterClient { client_id }).await?;
167 }
168 self.inner.connected.store(false, Ordering::SeqCst);
169 self.inner.registered.store(false, Ordering::SeqCst);
170 {
171 let mut transport = self.inner.transport.lock().await;
172 transport.close().await?;
173 }
174 if let Some(task) = self.inner.receive_task.lock().await.take() {
175 task.abort();
176 }
177 Ok(())
178 }
179
180 async fn send(&self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
181 let mut transport = self.inner.transport.lock().await;
182 transport.send(message).await
183 }
184}
185
186async fn process_messages(inner: Arc<MdpClientInner>, mut receiver: mpsc::UnboundedReceiver<ServerToClientMessage>) {
187 while let Some(message) = receiver.recv().await {
188 match message {
189 ServerToClientMessage::Ping { timestamp } => {
190 let mut transport = inner.transport.lock().await;
191 let _ = transport.send(ClientToServerMessage::Pong { timestamp }).await;
192 }
193 ServerToClientMessage::Pong { .. } => {}
194 ServerToClientMessage::CallClient(message) => {
195 let result = handle_invocation(&inner, &message).await;
196 let mut transport = inner.transport.lock().await;
197 let _ = transport.send(result).await;
198 }
199 }
200 }
201
202 inner.connected.store(false, Ordering::SeqCst);
203 inner.registered.store(false, Ordering::SeqCst);
204}
205
206async fn handle_invocation(
207 inner: &Arc<MdpClientInner>,
208 message: &CallClientRequest,
209) -> ClientToServerMessage {
210 let registry = inner.registry.read().unwrap().clone();
211 match registry.invoke(message).await {
212 Ok(data) => ClientToServerMessage::CallClientResult {
213 request_id: message.request_id.clone(),
214 ok: true,
215 data: Some(data),
216 error: None,
217 },
218 Err(error) => ClientToServerMessage::CallClientResult {
219 request_id: message.request_id.clone(),
220 ok: false,
221 data: None,
222 error: Some(crate::models::SerializedError::handler(error.to_string())),
223 },
224 }
225}