// modeldriveprotocol_client/registry.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::{json, Map, Value};

use crate::error::MdpClientError;
use crate::models::{
    ClientDescriptor, ClientInfo, EndpointOptions, HttpMethod, PathDescriptor,
    PathInvocationContext, PathRequest, PromptOptions, SkillOptions,
};
use crate::path_utils::{
    compare_path_specificity, is_path_pattern, is_prompt_path, is_skill_path, match_path_pattern,
};
use crate::protocol::CallClientRequest;

18type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, MdpClientError>> + Send>>;
19
20#[async_trait]
21pub trait PathHandler: Send + Sync {
22 async fn handle(
23 &self,
24 request: PathRequest,
25 context: PathInvocationContext,
26 ) -> Result<Value, MdpClientError>;
27}
28
29#[async_trait]
30impl<F, Fut> PathHandler for F
31where
32 F: Send + Sync + 'static + Fn(PathRequest, PathInvocationContext) -> Fut,
33 Fut: Future<Output = Result<Value, MdpClientError>> + Send + 'static,
34{
35 async fn handle(
36 &self,
37 request: PathRequest,
38 context: PathInvocationContext,
39 ) -> Result<Value, MdpClientError> {
40 (self)(request, context).await
41 }
42}
43
44#[derive(Clone)]
45struct ProcedureEntry {
46 descriptor: PathDescriptor,
47 handler: Arc<dyn PathHandler>,
48}
49
50#[derive(Clone)]
51struct ResolvedProcedure {
52 descriptor: PathDescriptor,
53 handler: Arc<dyn PathHandler>,
54 params: Map<String, Value>,
55 specificity: Vec<i32>,
56}
57
58#[derive(Default, Clone)]
59pub struct ProcedureRegistry {
60 entries: Vec<ProcedureEntry>,
61}
62
63impl ProcedureRegistry {
64 pub fn expose_endpoint<H, Fut>(
65 &mut self,
66 path: impl Into<String>,
67 method: HttpMethod,
68 handler: H,
69 options: EndpointOptions,
70 ) -> Result<(), MdpClientError>
71 where
72 H: Send + Sync + 'static + Fn(PathRequest, PathInvocationContext) -> Fut,
73 Fut: Future<Output = Result<Value, MdpClientError>> + Send + 'static,
74 {
75 let descriptor = PathDescriptor::Endpoint {
76 path: path.into(),
77 method,
78 description: options.description,
79 input_schema: options.input_schema,
80 output_schema: options.output_schema,
81 content_type: options.content_type,
82 };
83 self.register(descriptor, Arc::new(handler))
84 }
85
86 pub fn expose_skill_markdown(
87 &mut self,
88 path: impl Into<String>,
89 content: impl Into<String>,
90 options: SkillOptions,
91 ) -> Result<(), MdpClientError> {
92 let content = content.into();
93 let description = options
94 .description
95 .or_else(|| derive_markdown_description(&content));
96 let descriptor = PathDescriptor::Skill {
97 path: path.into(),
98 description,
99 content_type: options.content_type,
100 };
101 self.register(
102 descriptor,
103 Arc::new(move |_request: PathRequest, _context: PathInvocationContext| {
104 let content = content.clone();
105 Box::pin(async move { Ok(Value::String(content)) }) as HandlerFuture
106 }),
107 )
108 }
109
110 pub fn expose_prompt_markdown(
111 &mut self,
112 path: impl Into<String>,
113 content: impl Into<String>,
114 options: PromptOptions,
115 ) -> Result<(), MdpClientError> {
116 let content = content.into();
117 let description = options
118 .description
119 .or_else(|| derive_markdown_description(&content));
120 let descriptor = PathDescriptor::Prompt {
121 path: path.into(),
122 description,
123 input_schema: options.input_schema,
124 output_schema: options.output_schema,
125 };
126 self.register(
127 descriptor,
128 Arc::new(move |_request: PathRequest, _context: PathInvocationContext| {
129 let content = content.clone();
130 Box::pin(async move { Ok(json!({"messages": [{"role": "user", "content": content}]})) })
131 as HandlerFuture
132 }),
133 )
134 }
135
136 pub fn describe_paths(&self) -> Vec<PathDescriptor> {
137 self.entries.iter().map(|entry| entry.descriptor.clone()).collect()
138 }
139
140 pub fn describe(&self, client: &ClientInfo) -> ClientDescriptor {
141 ClientDescriptor::from_info(client, self.describe_paths())
142 }
143
144 pub fn unexpose(&mut self, path: &str, method: Option<HttpMethod>) -> Result<bool, MdpClientError> {
145 assert_path_pattern(path)?;
146 let method = method.as_ref();
147 if let Some(index) = self.entries.iter().position(|entry| match &entry.descriptor {
148 PathDescriptor::Endpoint {
149 path: descriptor_path,
150 method: descriptor_method,
151 ..
152 } => descriptor_path == path && method.map(|value| value == descriptor_method).unwrap_or(false),
153 PathDescriptor::Skill { path: descriptor_path, .. }
154 | PathDescriptor::Prompt { path: descriptor_path, .. } => descriptor_path == path,
155 }) {
156 self.entries.remove(index);
157 return Ok(true);
158 }
159 Ok(false)
160 }
161
162 pub async fn invoke(&self, message: &CallClientRequest) -> Result<Value, MdpClientError> {
163 let entry = self
164 .resolve_entry(&message.method, &message.path)
165 .ok_or_else(|| MdpClientError::UnknownPath {
166 path: message.path.clone(),
167 method: message.method.as_str().to_string(),
168 })?;
169
170 let request = PathRequest {
171 params: entry.params,
172 queries: message.query.clone().unwrap_or_default(),
173 body: message.body.clone(),
174 headers: message.headers.clone().unwrap_or_default(),
175 };
176 let context = PathInvocationContext {
177 request_id: message.request_id.clone(),
178 client_id: message.client_id.clone(),
179 path_type: entry.descriptor.descriptor_type().to_string(),
180 method: message.method.clone(),
181 path: message.path.clone(),
182 auth: message.auth.clone(),
183 };
184 entry.handler.handle(request, context).await
185 }
186
187 fn register(&mut self, descriptor: PathDescriptor, handler: Arc<dyn PathHandler>) -> Result<(), MdpClientError> {
188 assert_path_pattern(descriptor.path())?;
189 assert_descriptor_path_shape(&descriptor)?;
190 let key = registration_key(&descriptor);
191 let entry = ProcedureEntry { descriptor, handler };
192 if let Some(index) = self
193 .entries
194 .iter()
195 .position(|current| registration_key(¤t.descriptor) == key)
196 {
197 self.entries[index] = entry;
198 } else {
199 self.entries.push(entry);
200 }
201 Ok(())
202 }
203
204 fn resolve_entry(&self, method: &HttpMethod, path: &str) -> Option<ResolvedProcedure> {
205 let mut best_match: Option<ResolvedProcedure> = None;
206 for entry in &self.entries {
207 if !matches_method(&entry.descriptor, method) {
208 continue;
209 }
210 let Some(path_match) = match_path_pattern(entry.descriptor.path(), path) else {
211 continue;
212 };
213
214 if best_match
215 .as_ref()
216 .map(|current| compare_path_specificity(&path_match.specificity, ¤t.specificity) > 0)
217 .unwrap_or(true)
218 {
219 best_match = Some(ResolvedProcedure {
220 descriptor: entry.descriptor.clone(),
221 handler: entry.handler.clone(),
222 params: path_match.params,
223 specificity: path_match.specificity,
224 });
225 }
226 }
227 best_match
228 }
229}
230
231fn registration_key(descriptor: &PathDescriptor) -> String {
232 match descriptor {
233 PathDescriptor::Endpoint { path, method, .. } => format!("{} {}", method.as_str(), path),
234 PathDescriptor::Skill { path, .. } | PathDescriptor::Prompt { path, .. } => path.clone(),
235 }
236}
237
238fn matches_method(descriptor: &PathDescriptor, method: &HttpMethod) -> bool {
239 match descriptor {
240 PathDescriptor::Endpoint {
241 method: descriptor_method,
242 ..
243 } => descriptor_method == method,
244 PathDescriptor::Skill { .. } | PathDescriptor::Prompt { .. } => *method == HttpMethod::Get,
245 }
246}
247
248fn assert_path_pattern(path: &str) -> Result<(), MdpClientError> {
249 if is_path_pattern(path) {
250 Ok(())
251 } else {
252 Err(MdpClientError::InvalidPath(path.to_string()))
253 }
254}
255
256fn assert_descriptor_path_shape(descriptor: &PathDescriptor) -> Result<(), MdpClientError> {
257 match descriptor {
258 PathDescriptor::Endpoint { path, .. } => {
259 if is_skill_path(path) || is_prompt_path(path) {
260 return Err(MdpClientError::InvalidPath(path.clone()));
261 }
262 }
263 PathDescriptor::Skill { path, .. } => {
264 if !is_skill_path(path) {
265 return Err(MdpClientError::InvalidPath(path.clone()));
266 }
267 }
268 PathDescriptor::Prompt { path, .. } => {
269 if !is_prompt_path(path) {
270 return Err(MdpClientError::InvalidPath(path.clone()));
271 }
272 }
273 }
274 Ok(())
275}
276
/// Derives a short description from markdown `content`: the text of its first paragraph.
///
/// Lines are trimmed and lines starting with `#` (ATX headings) are skipped. The lines of the
/// first remaining paragraph are joined with single spaces. A paragraph ends at a blank line or
/// at a heading, because in Markdown a heading interrupts the paragraph before it.
///
/// Returns `None` when `content` contains no non-heading text.
fn derive_markdown_description(content: &str) -> Option<String> {
    let mut paragraph: Vec<&str> = Vec::new();
    for line in content.lines().map(str::trim) {
        let ends_paragraph = line.is_empty() || line.starts_with('#');
        if ends_paragraph {
            if !paragraph.is_empty() {
                break;
            }
            // Still looking for the first paragraph: skip leading blanks and headings.
            continue;
        }
        paragraph.push(line);
    }
    (!paragraph.is_empty()).then(|| paragraph.join(" "))
}