1use crate::{McpError, McpServer, Result};
7use async_trait::async_trait;
8use oxify_model::Workflow;
9use serde_json::{json, Value};
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16pub type WorkflowExecutor = Arc<
18 dyn Fn(Workflow, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync,
19>;
20
/// Settings that control how workflows are exposed as MCP tools.
#[derive(Debug, Clone)]
pub struct WorkflowServerConfig {
    /// String prepended to every generated tool name (default `"workflow_"`).
    pub tool_prefix: String,
    /// Whether tool listings should carry descriptions, as the name suggests
    /// (default `true`).
    pub include_descriptions: bool,
    /// Execution timeout in seconds (default `300`).
    /// NOTE(review): confirm where this is enforced; the tool-call path does not read it.
    pub execution_timeout_secs: u64,
}

impl Default for WorkflowServerConfig {
    /// `"workflow_"` prefix, descriptions enabled, 300-second timeout.
    fn default() -> Self {
        let tool_prefix = String::from("workflow_");
        let include_descriptions = true;
        let execution_timeout_secs = 300;

        Self {
            tool_prefix,
            include_descriptions,
            execution_timeout_secs,
        }
    }
}
41
42struct RegisteredWorkflow {
44 workflow: Workflow,
45 custom_input_schema: Option<Value>,
47 custom_description: Option<String>,
49}
50
51pub struct WorkflowServer {
74 workflows: Arc<RwLock<HashMap<String, RegisteredWorkflow>>>,
76 config: WorkflowServerConfig,
78 executor: Option<WorkflowExecutor>,
80}
81
82impl WorkflowServer {
83 pub fn new() -> Self {
85 Self {
86 workflows: Arc::new(RwLock::new(HashMap::new())),
87 config: WorkflowServerConfig::default(),
88 executor: None,
89 }
90 }
91
92 pub fn with_config(config: WorkflowServerConfig) -> Self {
94 Self {
95 workflows: Arc::new(RwLock::new(HashMap::new())),
96 config,
97 executor: None,
98 }
99 }
100
101 pub fn with_executor(mut self, executor: WorkflowExecutor) -> Self {
106 self.executor = Some(executor);
107 self
108 }
109
110 fn tool_name(&self, workflow: &Workflow) -> String {
112 let name = workflow
113 .metadata
114 .name
115 .to_lowercase()
116 .replace([' ', '-'], "_")
117 .chars()
118 .filter(|c| c.is_alphanumeric() || *c == '_')
119 .collect::<String>();
120
121 format!("{}{}", self.config.tool_prefix, name)
122 }
123
124 pub async fn register_workflow(&self, workflow: Workflow) -> Result<String> {
126 let tool_name = self.tool_name(&workflow);
127
128 let registered = RegisteredWorkflow {
129 workflow,
130 custom_input_schema: None,
131 custom_description: None,
132 };
133
134 let mut workflows = self.workflows.write().await;
135 workflows.insert(tool_name.clone(), registered);
136
137 tracing::info!(tool_name = %tool_name, "Registered workflow as MCP tool");
138
139 Ok(tool_name)
140 }
141
142 pub async fn register_workflow_with_schema(
144 &self,
145 workflow: Workflow,
146 input_schema: Value,
147 description: Option<String>,
148 ) -> Result<String> {
149 let tool_name = self.tool_name(&workflow);
150
151 let registered = RegisteredWorkflow {
152 workflow,
153 custom_input_schema: Some(input_schema),
154 custom_description: description,
155 };
156
157 let mut workflows = self.workflows.write().await;
158 workflows.insert(tool_name.clone(), registered);
159
160 tracing::info!(tool_name = %tool_name, "Registered workflow with custom schema as MCP tool");
161
162 Ok(tool_name)
163 }
164
165 pub async fn unregister_workflow(&self, tool_name: &str) -> Result<()> {
167 let mut workflows = self.workflows.write().await;
168
169 if workflows.remove(tool_name).is_some() {
170 tracing::info!(tool_name = %tool_name, "Unregistered workflow MCP tool");
171 Ok(())
172 } else {
173 Err(McpError::ToolNotFound(tool_name.to_string()))
174 }
175 }
176
177 pub async fn list_registered_tools(&self) -> Vec<String> {
179 let workflows = self.workflows.read().await;
180 workflows.keys().cloned().collect()
181 }
182
183 pub async fn get_workflow(&self, tool_name: &str) -> Option<Workflow> {
185 let workflows = self.workflows.read().await;
186 workflows.get(tool_name).map(|r| r.workflow.clone())
187 }
188
189 fn generate_input_schema(&self, workflow: &Workflow) -> Value {
191 let mut properties = serde_json::Map::new();
193 let mut required = Vec::new();
194
195 properties.insert(
197 "input".to_string(),
198 json!({
199 "type": "object",
200 "description": "Input variables for the workflow execution"
201 }),
202 );
203
204 for node in &workflow.nodes {
206 if let oxify_model::NodeKind::LLM(config) = &node.kind {
207 let template_vars = extract_template_variables(&config.prompt_template);
209 for var in template_vars {
210 if !properties.contains_key(&var) {
211 properties.insert(
212 var.clone(),
213 json!({
214 "type": "string",
215 "description": format!("Template variable for workflow: {}", var)
216 }),
217 );
218 required.push(json!(var));
219 }
220 }
221 }
222 }
223
224 json!({
225 "type": "object",
226 "properties": properties,
227 "required": required
228 })
229 }
230
231 fn generate_description(&self, workflow: &Workflow) -> String {
233 let base = workflow
234 .metadata
235 .description
236 .clone()
237 .unwrap_or_else(|| format!("Execute the {} workflow", workflow.metadata.name));
238
239 format!(
240 "{}. Nodes: {}, Edges: {}",
241 base,
242 workflow.nodes.len(),
243 workflow.edges.len()
244 )
245 }
246}
247
248impl Default for WorkflowServer {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
/// Collect the names of `{{ variable }}` placeholders in `text`, in order of appearance.
///
/// Each name is the text between `{{` and the next `}}`, trimmed of
/// surrounding whitespace. Placeholders whose name is empty or whitespace-only
/// (e.g. `{{ }}`) are skipped. Duplicates are kept. An opening `{{` with no
/// closing `}}` ends the scan.
fn extract_template_variables(text: &str) -> Vec<String> {
    let mut vars = Vec::new();
    let mut rest = text;

    while let Some(open) = rest.find("{{") {
        // "{{" is two ASCII bytes, so `open + 2` is always a char boundary.
        let body = &rest[open + 2..];
        let close = match body.find("}}") {
            Some(close) => close,
            None => break,
        };

        // Trim before the emptiness check so `{{ }}` does not yield "".
        let name = body[..close].trim();
        if !name.is_empty() {
            vars.push(name.to_string());
        }

        rest = &body[close + 2..];
    }

    vars
}
282
283#[async_trait]
284impl McpServer for WorkflowServer {
285 async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
286 let workflows = self.workflows.read().await;
287
288 let registered = workflows
289 .get(name)
290 .ok_or_else(|| McpError::ToolNotFound(name.to_string()))?;
291
292 let workflow = registered.workflow.clone();
293 drop(workflows); tracing::info!(
296 tool_name = %name,
297 workflow_id = %workflow.metadata.id,
298 "Executing workflow via MCP"
299 );
300
301 if let Some(executor) = &self.executor {
303 let result = executor(workflow, arguments).await?;
304 return Ok(result);
305 }
306
307 Ok(json!({
309 "status": "accepted",
310 "workflow_id": workflow.metadata.id.to_string(),
311 "workflow_name": workflow.metadata.name,
312 "message": "Workflow execution request accepted. No executor configured - configure with_executor() for actual execution.",
313 "input": arguments,
314 "nodes": workflow.nodes.len(),
315 "edges": workflow.edges.len()
316 }))
317 }
318
319 async fn list_tools(&self) -> Result<Vec<Value>> {
320 let workflows = self.workflows.read().await;
321 let mut tools = Vec::new();
322
323 for (tool_name, registered) in workflows.iter() {
324 let description = registered
325 .custom_description
326 .clone()
327 .unwrap_or_else(|| self.generate_description(®istered.workflow));
328
329 let input_schema = registered
330 .custom_input_schema
331 .clone()
332 .unwrap_or_else(|| self.generate_input_schema(®istered.workflow));
333
334 tools.push(json!({
335 "name": tool_name,
336 "description": description,
337 "inputSchema": input_schema
338 }));
339 }
340
341 Ok(tools)
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use oxify_model::{Edge, Node, NodeKind};

    /// Build a two-node `Start -> End` workflow called `name` with a fixed description.
    fn sample_workflow(name: &str) -> Workflow {
        let mut wf = Workflow::new(name.to_string());
        wf.metadata.description = Some("A test workflow".to_string());

        let start_node = Node::new("Start".to_string(), NodeKind::Start);
        let end_node = Node::new("End".to_string(), NodeKind::End);
        let (from, to) = (start_node.id, end_node.id);

        wf.add_node(start_node);
        wf.add_node(end_node);
        wf.add_edge(Edge::new(from, to));

        wf
    }

    #[tokio::test]
    async fn test_register_workflow() {
        let server = WorkflowServer::new();

        let registered_name = server
            .register_workflow(sample_workflow("Test Workflow"))
            .await
            .unwrap();
        assert_eq!(registered_name, "workflow_test_workflow");

        let registered = server.list_registered_tools().await;
        assert_eq!(registered.len(), 1);
        assert!(registered.iter().any(|t| t == "workflow_test_workflow"));
    }

    #[tokio::test]
    async fn test_list_tools() {
        let server = WorkflowServer::new();
        for title in ["Workflow One", "Workflow Two"] {
            server
                .register_workflow(sample_workflow(title))
                .await
                .unwrap();
        }

        let listed = server.list_tools().await.unwrap();
        assert_eq!(listed.len(), 2);

        let names: Vec<&str> = listed
            .iter()
            .map(|tool| tool["name"].as_str().expect("tool name should be a string"))
            .collect();
        assert!(names.contains(&"workflow_workflow_one"));
        assert!(names.contains(&"workflow_workflow_two"));
    }

    #[tokio::test]
    async fn test_call_tool_without_executor() {
        let server = WorkflowServer::new();
        let tool = server
            .register_workflow(sample_workflow("Execute Me"))
            .await
            .unwrap();

        let reply = server
            .call_tool(&tool, json!({"input": {"key": "value"}}))
            .await
            .unwrap();

        assert_eq!(reply["status"], "accepted");
        assert_eq!(reply["workflow_name"], "Execute Me");
    }

    #[tokio::test]
    async fn test_call_tool_with_executor() {
        let executor: WorkflowExecutor = Arc::new(|workflow, args| {
            Box::pin(async move {
                Ok(json!({
                    "executed": true,
                    "workflow_name": workflow.metadata.name,
                    "received_args": args
                }))
            })
        });

        let server = WorkflowServer::new().with_executor(executor);
        let tool = server
            .register_workflow(sample_workflow("Custom Execute"))
            .await
            .unwrap();

        let reply = server
            .call_tool(&tool, json!({"test": "data"}))
            .await
            .unwrap();

        assert_eq!(reply["executed"], true);
        assert_eq!(reply["workflow_name"], "Custom Execute");
        assert_eq!(reply["received_args"]["test"], "data");
    }

    #[tokio::test]
    async fn test_unregister_workflow() {
        let server = WorkflowServer::new();
        let tool = server
            .register_workflow(sample_workflow("To Remove"))
            .await
            .unwrap();
        assert_eq!(server.list_registered_tools().await.len(), 1);

        server.unregister_workflow(&tool).await.unwrap();
        assert!(server.list_registered_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let server = WorkflowServer::new();

        match server.call_tool("nonexistent_tool", json!({})).await {
            Err(McpError::ToolNotFound(missing)) => assert_eq!(missing, "nonexistent_tool"),
            Err(_) => panic!("Expected ToolNotFound error"),
            Ok(_) => panic!("Expected an error for an unregistered tool"),
        }
    }

    #[tokio::test]
    async fn test_custom_config() {
        let server = WorkflowServer::with_config(WorkflowServerConfig {
            tool_prefix: "oxify_".to_string(),
            include_descriptions: true,
            execution_timeout_secs: 60,
        });

        let tool = server
            .register_workflow(sample_workflow("My Workflow"))
            .await
            .unwrap();

        assert_eq!(tool, "oxify_my_workflow");
    }

    #[tokio::test]
    async fn test_register_with_custom_schema() {
        let server = WorkflowServer::new();
        let schema = json!({
            "type": "object",
            "properties": {
                "message": {"type": "string"},
                "count": {"type": "integer"}
            },
            "required": ["message"]
        });

        let tool_name = server
            .register_workflow_with_schema(
                sample_workflow("Typed Workflow"),
                schema.clone(),
                Some("Custom description".to_string()),
            )
            .await
            .unwrap();

        let listed = server.list_tools().await.unwrap();
        let tool = listed
            .iter()
            .find(|t| t["name"] == tool_name)
            .expect("registered tool should be listed");

        assert_eq!(tool["description"], "Custom description");
        assert_eq!(tool["inputSchema"], schema);
    }

    #[test]
    fn test_extract_template_variables() {
        let found = extract_template_variables("Hello {{name}}, your order {{order_id}} is ready!");
        assert_eq!(found.len(), 2);
        assert!(found.iter().any(|v| v == "name"));
        assert!(found.iter().any(|v| v == "order_id"));
    }

    #[test]
    fn test_extract_template_variables_empty() {
        assert!(extract_template_variables("No variables here").is_empty());
    }

    #[test]
    fn test_extract_template_variables_whitespace() {
        let found = extract_template_variables("Hello {{ name }}, {{ count }} items");
        assert_eq!(found.len(), 2);
        assert!(found.iter().any(|v| v == "name"));
        assert!(found.iter().any(|v| v == "count"));
    }
}