Skip to main content

modeldriveprotocol_client/
registry.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::{json, Map, Value};
7
8use crate::error::MdpClientError;
9use crate::models::{
10    ClientDescriptor, ClientInfo, EndpointOptions, HttpMethod, PathDescriptor, PathInvocationContext,
11    PathRequest, PromptOptions, SkillOptions,
12};
13use crate::path_utils::{
14    compare_path_specificity, is_path_pattern, is_prompt_path, is_skill_path, match_path_pattern,
15};
16use crate::protocol::CallClientRequest;
17
18type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, MdpClientError>> + Send>>;
19
20#[async_trait]
21pub trait PathHandler: Send + Sync {
22    async fn handle(
23        &self,
24        request: PathRequest,
25        context: PathInvocationContext,
26    ) -> Result<Value, MdpClientError>;
27}
28
29#[async_trait]
30impl<F, Fut> PathHandler for F
31where
32    F: Send + Sync + 'static + Fn(PathRequest, PathInvocationContext) -> Fut,
33    Fut: Future<Output = Result<Value, MdpClientError>> + Send + 'static,
34{
35    async fn handle(
36        &self,
37        request: PathRequest,
38        context: PathInvocationContext,
39    ) -> Result<Value, MdpClientError> {
40        (self)(request, context).await
41    }
42}
43
44#[derive(Clone)]
45struct ProcedureEntry {
46    descriptor: PathDescriptor,
47    handler: Arc<dyn PathHandler>,
48}
49
50#[derive(Clone)]
51struct ResolvedProcedure {
52    descriptor: PathDescriptor,
53    handler: Arc<dyn PathHandler>,
54    params: Map<String, Value>,
55    specificity: Vec<i32>,
56}
57
58#[derive(Default, Clone)]
59pub struct ProcedureRegistry {
60    entries: Vec<ProcedureEntry>,
61}
62
63impl ProcedureRegistry {
64    pub fn expose_endpoint<H, Fut>(
65        &mut self,
66        path: impl Into<String>,
67        method: HttpMethod,
68        handler: H,
69        options: EndpointOptions,
70    ) -> Result<(), MdpClientError>
71    where
72        H: Send + Sync + 'static + Fn(PathRequest, PathInvocationContext) -> Fut,
73        Fut: Future<Output = Result<Value, MdpClientError>> + Send + 'static,
74    {
75        let descriptor = PathDescriptor::Endpoint {
76            path: path.into(),
77            method,
78            description: options.description,
79            input_schema: options.input_schema,
80            output_schema: options.output_schema,
81            content_type: options.content_type,
82        };
83        self.register(descriptor, Arc::new(handler))
84    }
85
86    pub fn expose_skill_markdown(
87        &mut self,
88        path: impl Into<String>,
89        content: impl Into<String>,
90        options: SkillOptions,
91    ) -> Result<(), MdpClientError> {
92        let content = content.into();
93        let description = options
94            .description
95            .or_else(|| derive_markdown_description(&content));
96        let descriptor = PathDescriptor::Skill {
97            path: path.into(),
98            description,
99            content_type: options.content_type,
100        };
101        self.register(
102            descriptor,
103            Arc::new(move |_request: PathRequest, _context: PathInvocationContext| {
104                let content = content.clone();
105                Box::pin(async move { Ok(Value::String(content)) }) as HandlerFuture
106            }),
107        )
108    }
109
110    pub fn expose_prompt_markdown(
111        &mut self,
112        path: impl Into<String>,
113        content: impl Into<String>,
114        options: PromptOptions,
115    ) -> Result<(), MdpClientError> {
116        let content = content.into();
117        let description = options
118            .description
119            .or_else(|| derive_markdown_description(&content));
120        let descriptor = PathDescriptor::Prompt {
121            path: path.into(),
122            description,
123            input_schema: options.input_schema,
124            output_schema: options.output_schema,
125        };
126        self.register(
127            descriptor,
128            Arc::new(move |_request: PathRequest, _context: PathInvocationContext| {
129                let content = content.clone();
130                Box::pin(async move { Ok(json!({"messages": [{"role": "user", "content": content}]})) })
131                    as HandlerFuture
132            }),
133        )
134    }
135
136    pub fn describe_paths(&self) -> Vec<PathDescriptor> {
137        self.entries.iter().map(|entry| entry.descriptor.clone()).collect()
138    }
139
140    pub fn describe(&self, client: &ClientInfo) -> ClientDescriptor {
141        ClientDescriptor::from_info(client, self.describe_paths())
142    }
143
144    pub fn unexpose(&mut self, path: &str, method: Option<HttpMethod>) -> Result<bool, MdpClientError> {
145        assert_path_pattern(path)?;
146        let method = method.as_ref();
147        if let Some(index) = self.entries.iter().position(|entry| match &entry.descriptor {
148            PathDescriptor::Endpoint {
149                path: descriptor_path,
150                method: descriptor_method,
151                ..
152            } => descriptor_path == path && method.map(|value| value == descriptor_method).unwrap_or(false),
153            PathDescriptor::Skill { path: descriptor_path, .. }
154            | PathDescriptor::Prompt { path: descriptor_path, .. } => descriptor_path == path,
155        }) {
156            self.entries.remove(index);
157            return Ok(true);
158        }
159        Ok(false)
160    }
161
162    pub async fn invoke(&self, message: &CallClientRequest) -> Result<Value, MdpClientError> {
163        let entry = self
164            .resolve_entry(&message.method, &message.path)
165            .ok_or_else(|| MdpClientError::UnknownPath {
166                path: message.path.clone(),
167                method: message.method.as_str().to_string(),
168            })?;
169
170        let request = PathRequest {
171            params: entry.params,
172            queries: message.query.clone().unwrap_or_default(),
173            body: message.body.clone(),
174            headers: message.headers.clone().unwrap_or_default(),
175        };
176        let context = PathInvocationContext {
177            request_id: message.request_id.clone(),
178            client_id: message.client_id.clone(),
179            path_type: entry.descriptor.descriptor_type().to_string(),
180            method: message.method.clone(),
181            path: message.path.clone(),
182            auth: message.auth.clone(),
183        };
184        entry.handler.handle(request, context).await
185    }
186
187    fn register(&mut self, descriptor: PathDescriptor, handler: Arc<dyn PathHandler>) -> Result<(), MdpClientError> {
188        assert_path_pattern(descriptor.path())?;
189        assert_descriptor_path_shape(&descriptor)?;
190        let key = registration_key(&descriptor);
191        let entry = ProcedureEntry { descriptor, handler };
192        if let Some(index) = self
193            .entries
194            .iter()
195            .position(|current| registration_key(&current.descriptor) == key)
196        {
197            self.entries[index] = entry;
198        } else {
199            self.entries.push(entry);
200        }
201        Ok(())
202    }
203
204    fn resolve_entry(&self, method: &HttpMethod, path: &str) -> Option<ResolvedProcedure> {
205        let mut best_match: Option<ResolvedProcedure> = None;
206        for entry in &self.entries {
207            if !matches_method(&entry.descriptor, method) {
208                continue;
209            }
210            let Some(path_match) = match_path_pattern(entry.descriptor.path(), path) else {
211                continue;
212            };
213
214            if best_match
215                .as_ref()
216                .map(|current| compare_path_specificity(&path_match.specificity, &current.specificity) > 0)
217                .unwrap_or(true)
218            {
219                best_match = Some(ResolvedProcedure {
220                    descriptor: entry.descriptor.clone(),
221                    handler: entry.handler.clone(),
222                    params: path_match.params,
223                    specificity: path_match.specificity,
224                });
225            }
226        }
227        best_match
228    }
229}
230
231fn registration_key(descriptor: &PathDescriptor) -> String {
232    match descriptor {
233        PathDescriptor::Endpoint { path, method, .. } => format!("{} {}", method.as_str(), path),
234        PathDescriptor::Skill { path, .. } | PathDescriptor::Prompt { path, .. } => path.clone(),
235    }
236}
237
238fn matches_method(descriptor: &PathDescriptor, method: &HttpMethod) -> bool {
239    match descriptor {
240        PathDescriptor::Endpoint {
241            method: descriptor_method,
242            ..
243        } => descriptor_method == method,
244        PathDescriptor::Skill { .. } | PathDescriptor::Prompt { .. } => *method == HttpMethod::Get,
245    }
246}
247
248fn assert_path_pattern(path: &str) -> Result<(), MdpClientError> {
249    if is_path_pattern(path) {
250        Ok(())
251    } else {
252        Err(MdpClientError::InvalidPath(path.to_string()))
253    }
254}
255
256fn assert_descriptor_path_shape(descriptor: &PathDescriptor) -> Result<(), MdpClientError> {
257    match descriptor {
258        PathDescriptor::Endpoint { path, .. } => {
259            if is_skill_path(path) || is_prompt_path(path) {
260                return Err(MdpClientError::InvalidPath(path.clone()));
261            }
262        }
263        PathDescriptor::Skill { path, .. } => {
264            if !is_skill_path(path) {
265                return Err(MdpClientError::InvalidPath(path.clone()));
266            }
267        }
268        PathDescriptor::Prompt { path, .. } => {
269            if !is_prompt_path(path) {
270                return Err(MdpClientError::InvalidPath(path.clone()));
271            }
272        }
273    }
274    Ok(())
275}
276
/// Derives a one-line description from Markdown `content`.
///
/// Returns the first paragraph of body text, with each line trimmed and the lines joined by
/// single spaces. Heading lines (lines starting with `#`) and blank lines before that
/// paragraph are skipped. The paragraph ends at the next blank line or at the next heading,
/// since a heading opens a new section and must not leak its text into the description.
/// Returns `None` when `content` has no body text.
fn derive_markdown_description(content: &str) -> Option<String> {
    let mut paragraph: Vec<&str> = Vec::new();
    for line in content.lines().map(str::trim) {
        let is_heading = line.starts_with('#');
        if line.is_empty() || is_heading {
            if paragraph.is_empty() {
                // Still looking for the first body paragraph.
                continue;
            }
            // A blank line or a heading ends the paragraph we collected.
            break;
        }
        paragraph.push(line);
    }
    (!paragraph.is_empty()).then(|| paragraph.join(" "))
}