1pub mod config;
46pub mod default_executor;
47pub mod executor;
48pub mod hooks;
49pub mod retry;
50pub mod timeout;
51
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::collections::HashMap;
55use thulp_core::ToolResult;
56
57use thulp_core::{ToolCall, Transport};
58
59pub use config::{
60 BackoffStrategy, ExecutionConfig, RetryConfig, RetryableError, TimeoutAction, TimeoutConfig,
61};
62pub use default_executor::DefaultSkillExecutor;
63pub use executor::{ExecutionContext, SkillExecutor, StepResult};
64pub use hooks::{CompositeHooks, ExecutionHooks, NoOpHooks, TracingHooks};
65pub use retry::{calculate_delay, is_error_retryable, with_retry, RetryError};
66pub use timeout::{with_timeout, with_timeout_infallible, TimeoutError};
67
68#[cfg(test)]
69use async_trait::async_trait;
70
71pub type Result<T> = std::result::Result<T, SkillError>;
73
/// Errors produced while executing a [`Skill`].
#[derive(Debug, thiserror::Error)]
pub enum SkillError {
    /// An execution failure described by the contained message.
    #[error("Execution error: {0}")]
    Execution(String),

    /// A skill could not be found; the payload is the name that was requested.
    #[error("Skill not found: {0}")]
    NotFound(String),

    /// Invalid configuration, with a description of the problem.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// A step's tool call did not finish within its per-attempt timeout, and
    /// the timeout was either not retryable or the retry budget was spent.
    #[error("Step '{step}' timed out after {duration:?}")]
    StepTimeout {
        /// Name of the step that timed out.
        step: String,
        /// Timeout that was applied to each attempt of the step.
        duration: std::time::Duration,
    },

    /// The whole skill exceeded the configured skill-level timeout while the
    /// timeout action was `TimeoutAction::Fail`.
    #[error("Skill timed out after {duration:?}")]
    SkillTimeout { duration: std::time::Duration },

    /// A step's tool call returned an error that was not retryable, or it kept
    /// failing until every allowed attempt was used.
    #[error("Step '{step}' failed after {attempts} attempts: {message}")]
    RetryExhausted {
        /// Name of the failing step.
        step: String,
        /// Total number of attempts made, including the first one.
        attempts: usize,
        /// Text of the last transport error.
        message: String,
    },
}
102
/// A single step of a [`Skill`]: one call to a named tool.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SkillStep {
    /// Step name. After the step succeeds, its output data is stored in the
    /// execution context under this name so later steps can reference it
    /// with a `{{name}}` placeholder.
    pub name: String,

    /// Name of the tool invoked by this step.
    pub tool: String,

    /// Arguments passed to the tool. `{{key}}` placeholders are substituted
    /// from the skill inputs and earlier step outputs before the call.
    /// Defaults to JSON `null` when omitted.
    #[serde(default)]
    pub arguments: Value,

    /// When `true`, a failure of this step is recorded as a failed result and
    /// execution continues with the next step.
    #[serde(default)]
    pub continue_on_error: bool,

    /// Per-attempt timeout in seconds; when set it overrides the configured
    /// step timeout for this step.
    #[serde(default)]
    pub timeout_secs: Option<u64>,

    /// Maximum number of retries; when set it overrides the configured retry
    /// count for this step.
    #[serde(default)]
    pub max_retries: Option<usize>,
}
128
/// A named, ordered sequence of tool-calling steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Skill {
    /// Skill name; used as the key when registered in a [`SkillRegistry`].
    pub name: String,

    /// Human-readable description of the skill.
    pub description: String,

    /// Names of the inputs the skill expects.
    // NOTE(review): nothing visible in this file validates the arguments
    // passed to `execute` against this list — confirm whether it is purely
    // informational.
    #[serde(default)]
    pub inputs: Vec<String>,

    /// Steps, executed in order.
    pub steps: Vec<SkillStep>,
}
145
146impl Skill {
147 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
149 Self {
150 name: name.into(),
151 description: description.into(),
152 inputs: Vec::new(),
153 steps: Vec::new(),
154 }
155 }
156
157 pub fn with_input(mut self, input: impl Into<String>) -> Self {
159 self.inputs.push(input.into());
160 self
161 }
162
163 pub fn with_step(mut self, step: SkillStep) -> Self {
165 self.steps.push(step);
166 self
167 }
168}
169
170impl Skill {
171 pub async fn execute<T: Transport>(
173 &self,
174 transport: &T,
175 input_args: &HashMap<String, serde_json::Value>,
176 ) -> Result<SkillResult> {
177 self.execute_with_config(transport, input_args, &ExecutionConfig::default())
178 .await
179 }
180
181 pub async fn execute_with_config<T: Transport>(
206 &self,
207 transport: &T,
208 input_args: &HashMap<String, serde_json::Value>,
209 config: &ExecutionConfig,
210 ) -> Result<SkillResult> {
211 let skill_timeout = config.timeout.skill_timeout;
212
213 let result = tokio::time::timeout(skill_timeout, async {
215 self.execute_steps_with_config(transport, input_args, config)
216 .await
217 })
218 .await;
219
220 match result {
221 Ok(inner_result) => inner_result,
222 Err(_elapsed) => {
223 match config.timeout.timeout_action {
225 TimeoutAction::Fail => Err(SkillError::SkillTimeout {
226 duration: skill_timeout,
227 }),
228 TimeoutAction::Skip | TimeoutAction::Partial => {
229 Ok(SkillResult {
231 success: false,
232 step_results: vec![],
233 output: None,
234 error: Some(format!("Skill timed out after {:?}", skill_timeout)),
235 })
236 }
237 }
238 }
239 }
240 }
241
242 async fn execute_steps_with_config<T: Transport>(
244 &self,
245 transport: &T,
246 input_args: &HashMap<String, serde_json::Value>,
247 config: &ExecutionConfig,
248 ) -> Result<SkillResult> {
249 use std::time::Duration;
250
251 let mut step_results = Vec::new();
252 let mut context = input_args.clone();
253
254 for step in &self.steps {
255 let step_timeout = step
257 .timeout_secs
258 .map(Duration::from_secs)
259 .unwrap_or(config.timeout.step_timeout);
260
261 let max_retries = step.max_retries.unwrap_or(config.retry.max_retries);
263 let step_retry_config = RetryConfig {
264 max_retries,
265 ..config.retry.clone()
266 };
267
268 let prepared_args = self.prepare_arguments(&step.arguments, &context)?;
270
271 let tool_call = ToolCall {
272 tool: step.tool.clone(),
273 arguments: prepared_args,
274 };
275
276 let step_result = self
278 .execute_step_with_retry_timeout(
279 transport,
280 &tool_call,
281 &step.name,
282 step_timeout,
283 &step_retry_config,
284 )
285 .await;
286
287 match step_result {
288 Ok(result) => {
289 step_results.push((step.name.clone(), result.clone()));
290
291 context.insert(
293 step.name.clone(),
294 result.data.clone().unwrap_or(Value::Null),
295 );
296
297 if step_results.len() == self.steps.len() {
299 return Ok(SkillResult {
300 success: true,
301 step_results,
302 output: result.data,
303 error: None,
304 });
305 }
306 }
307 Err(e) => {
308 if step.continue_on_error {
309 step_results.push((step.name.clone(), ToolResult::failure(e.to_string())));
311 } else {
312 match &config.timeout.timeout_action {
314 TimeoutAction::Skip => {
315 step_results
316 .push((step.name.clone(), ToolResult::failure(e.to_string())));
317 }
319 TimeoutAction::Partial => {
320 return Ok(SkillResult {
321 success: false,
322 step_results,
323 output: None,
324 error: Some(e.to_string()),
325 });
326 }
327 TimeoutAction::Fail => {
328 return Err(e);
329 }
330 }
331 }
332 }
333 }
334 }
335
336 Ok(SkillResult {
337 success: true,
338 step_results,
339 output: None,
340 error: None,
341 })
342 }
343
344 async fn execute_step_with_retry_timeout<T: Transport>(
346 &self,
347 transport: &T,
348 tool_call: &ToolCall,
349 step_name: &str,
350 timeout: std::time::Duration,
351 retry_config: &RetryConfig,
352 ) -> Result<ToolResult> {
353 let mut attempts = 0;
354
355 loop {
356 attempts += 1;
357
358 let result = tokio::time::timeout(timeout, transport.call(tool_call)).await;
360
361 match result {
362 Ok(Ok(tool_result)) => {
363 return Ok(tool_result);
365 }
366 Ok(Err(e)) => {
367 let error_msg = e.to_string();
369
370 if attempts > retry_config.max_retries
371 || !is_error_retryable(&error_msg, retry_config)
372 {
373 return Err(SkillError::RetryExhausted {
374 step: step_name.to_string(),
375 attempts,
376 message: error_msg,
377 });
378 }
379
380 let delay = calculate_delay(retry_config, attempts);
381 tracing::warn!(
382 step = step_name,
383 attempt = attempts,
384 max_retries = retry_config.max_retries,
385 delay_ms = delay.as_millis() as u64,
386 error = %e,
387 "Retrying step after error"
388 );
389 tokio::time::sleep(delay).await;
390 }
391 Err(_elapsed) => {
392 if attempts > retry_config.max_retries
394 || !retry_config
395 .retryable_errors
396 .contains(&RetryableError::Timeout)
397 {
398 return Err(SkillError::StepTimeout {
399 step: step_name.to_string(),
400 duration: timeout,
401 });
402 }
403
404 let delay = calculate_delay(retry_config, attempts);
405 tracing::warn!(
406 step = step_name,
407 attempt = attempts,
408 max_retries = retry_config.max_retries,
409 delay_ms = delay.as_millis() as u64,
410 "Retrying step after timeout"
411 );
412 tokio::time::sleep(delay).await;
413 }
414 }
415 }
416 }
417
418 fn prepare_arguments(
420 &self,
421 args: &serde_json::Value,
422 context: &HashMap<String, serde_json::Value>,
423 ) -> Result<serde_json::Value> {
424 let args_str = serde_json::to_string(args)
427 .map_err(|e| SkillError::InvalidConfig(format!("Failed to serialize args: {}", e)))?;
428
429 let mut processed_str = args_str.clone();
430
431 for (key, value) in context {
433 let placeholder = format!("{{{{{}}}}}", key);
434 let value_str = serde_json::to_string(value).map_err(|e| {
435 SkillError::InvalidConfig(format!("Failed to serialize value: {}", e))
436 })?;
437 processed_str = processed_str.replace(&placeholder, &value_str);
438 }
439
440 serde_json::from_str(&processed_str).map_err(|e| {
441 SkillError::InvalidConfig(format!("Failed to parse processed args: {}", e))
442 })
443 }
444}
445
/// Outcome of executing a [`Skill`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillResult {
    /// `false` when execution stopped early: on a skill-level timeout
    /// (`Skip`/`Partial` action), or on a step failure under
    /// `TimeoutAction::Partial`. Steps that failed but were skipped (via
    /// `continue_on_error` or `TimeoutAction::Skip`) do not make this `false`.
    pub success: bool,

    /// `(step name, result)` pairs in execution order. Skipped failures appear
    /// as failure results.
    pub step_results: Vec<(String, ToolResult)>,

    /// Data returned by the final step when that step succeeded; `None`
    /// otherwise.
    pub output: Option<Value>,

    /// Error message when execution stopped early; `None` on success.
    pub error: Option<String>,
}
461
/// In-memory collection of [`Skill`]s, keyed by skill name.
#[derive(Debug, Default)]
pub struct SkillRegistry {
    // Skill name -> skill; at most one skill per name.
    skills: HashMap<String, Skill>,
}
467
468impl SkillRegistry {
469 pub fn new() -> Self {
471 Self::default()
472 }
473
474 pub fn register(&mut self, skill: Skill) {
476 self.skills.insert(skill.name.clone(), skill);
477 }
478
479 pub fn get(&self, name: &str) -> Option<&Skill> {
481 self.skills.get(name)
482 }
483
484 pub fn list(&self) -> Vec<String> {
486 self.skills.keys().cloned().collect()
487 }
488
489 pub fn unregister(&mut self, name: &str) -> Option<Skill> {
491 self.skills.remove(name)
492 }
493}
494
/// Test-only [`Transport`] that answers tool calls from a fixed table of results.
#[cfg(test)]
struct MockTransport {
    // Canned result per tool name; calls to any other tool fail with
    // `ToolNotFound`.
    responses: HashMap<String, ToolResult>,
}
500
#[cfg(test)]
#[async_trait]
impl Transport for MockTransport {
    /// No-op: the mock has no connection to open.
    async fn connect(&mut self) -> thulp_core::Result<()> {
        Ok(())
    }

    /// No-op: the mock has no connection to close.
    async fn disconnect(&mut self) -> thulp_core::Result<()> {
        Ok(())
    }

    /// The mock always reports itself as connected.
    fn is_connected(&self) -> bool {
        true
    }

    /// The mock advertises no tool definitions.
    async fn list_tools(&self) -> thulp_core::Result<Vec<thulp_core::ToolDefinition>> {
        Ok(Vec::new())
    }

    /// Returns a clone of the canned result for `call.tool`, or
    /// `ToolNotFound` when no result was registered for that tool.
    async fn call(&self, call: &ToolCall) -> thulp_core::Result<ToolResult> {
        self.responses
            .get(&call.tool)
            .cloned()
            .ok_or_else(|| thulp_core::Error::ToolNotFound(call.tool.clone()))
    }
}
528
#[cfg(test)]
impl MockTransport {
    /// Creates a mock with no canned responses; every call fails until
    /// responses are added.
    fn new() -> Self {
        let responses = HashMap::default();
        Self { responses }
    }

    /// Registers `result` as the response for `tool_name`, replacing any
    /// earlier one, and returns the updated mock.
    fn with_response(mut self, tool_name: &str, result: ToolResult) -> Self {
        let key = String::from(tool_name);
        self.responses.insert(key, result);
        self
    }
}
542
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a step with the given name, tool and arguments. Every optional
    /// setting keeps its `Default` value: no continue-on-error and no per-step
    /// timeout or retry override.
    fn make_step(name: &str, tool: &str, arguments: Value) -> SkillStep {
        SkillStep {
            name: name.to_string(),
            tool: tool.to_string(),
            arguments,
            ..SkillStep::default()
        }
    }

    /// Transport whose calls all sleep for `delay` and then succeed with `{}`.
    struct SlowTransport {
        delay: Duration,
    }

    #[async_trait]
    impl Transport for SlowTransport {
        async fn connect(&mut self) -> thulp_core::Result<()> {
            Ok(())
        }

        async fn disconnect(&mut self) -> thulp_core::Result<()> {
            Ok(())
        }

        fn is_connected(&self) -> bool {
            true
        }

        async fn list_tools(&self) -> thulp_core::Result<Vec<thulp_core::ToolDefinition>> {
            Ok(Vec::new())
        }

        async fn call(&self, _call: &ToolCall) -> thulp_core::Result<ToolResult> {
            tokio::time::sleep(self.delay).await;
            Ok(ToolResult::success(serde_json::json!({})))
        }
    }

    #[test]
    fn test_skill_creation() {
        let created = Skill::new("test_skill", "A test skill");

        assert_eq!(created.name, "test_skill");
        assert_eq!(created.description, "A test skill");
    }

    #[test]
    fn test_skill_builder() {
        let built = Skill::new("search_and_summarize", "Search and summarize results")
            .with_input("query")
            .with_step(make_step(
                "search",
                "web_search",
                serde_json::json!({"query": "{{query}}"}),
            ))
            .with_step(SkillStep {
                timeout_secs: Some(30),
                max_retries: Some(2),
                ..make_step(
                    "summarize",
                    "summarize",
                    serde_json::json!({"text": "{{search.results}}"}),
                )
            });

        assert_eq!(built.inputs.len(), 1);
        assert_eq!(built.steps.len(), 2);
        let second = &built.steps[1];
        assert_eq!(second.timeout_secs, Some(30));
        assert_eq!(second.max_retries, Some(2));
    }

    #[test]
    fn test_registry() {
        let mut registry = SkillRegistry::new();
        registry.register(Skill::new("test", "Test skill"));

        assert!(registry.get("test").is_some());
        assert_eq!(registry.list().len(), 1);
    }

    #[test]
    fn test_registry_unregister() {
        let mut registry = SkillRegistry::new();
        registry.register(Skill::new("test", "Test skill"));

        assert!(registry.unregister("test").is_some());
        assert_eq!(registry.list().len(), 0);
    }

    #[tokio::test]
    async fn test_skill_execution() {
        let mock = MockTransport::new()
            .with_response(
                "search",
                ToolResult::success(serde_json::json!({"results": ["result1", "result2"]})),
            )
            .with_response(
                "summarize",
                ToolResult::success(serde_json::json!("Summary of results")),
            );

        let pipeline = Skill::new("search_and_summarize", "Search and summarize results")
            .with_input("query")
            .with_step(make_step(
                "search",
                "search",
                serde_json::json!({"query": "test query"}),
            ))
            .with_step(make_step(
                "summarize",
                "summarize",
                serde_json::json!({"text": "summary text"}),
            ));

        let outcome = pipeline.execute(&mock, &HashMap::new()).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 2);
        assert!(outcome.output.is_some());
    }

    #[tokio::test]
    async fn test_skill_execution_with_config() {
        let mock = MockTransport::new().with_response(
            "test_tool",
            ToolResult::success(serde_json::json!({"ok": true})),
        );
        let single = Skill::new("test_skill", "Test skill").with_step(make_step(
            "step1",
            "test_tool",
            serde_json::json!({}),
        ));
        let config = ExecutionConfig::new()
            .with_timeout(TimeoutConfig::new().with_step_timeout(Duration::from_secs(10)))
            .with_retry(RetryConfig::no_retries());

        let outcome = single
            .execute_with_config(&mock, &HashMap::new(), &config)
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 1);
    }

    #[tokio::test]
    async fn test_skill_step_timeout() {
        let slow = SlowTransport {
            delay: Duration::from_secs(10),
        };
        let skill = Skill::new("slow_skill", "Slow skill").with_step(make_step(
            "slow_step",
            "slow_tool",
            serde_json::json!({}),
        ));
        let config = ExecutionConfig::new()
            .with_timeout(TimeoutConfig::new().with_step_timeout(Duration::from_millis(50)))
            .with_retry(RetryConfig::no_retries());

        let outcome = skill
            .execute_with_config(&slow, &HashMap::new(), &config)
            .await;

        assert!(outcome.is_err());
        match outcome {
            Err(SkillError::StepTimeout {
                step: timed_out, ..
            }) => assert_eq!(timed_out, "slow_step"),
            _ => panic!("Expected StepTimeout error"),
        }
    }

    #[tokio::test]
    async fn test_skill_step_per_step_timeout_override() {
        let slow = SlowTransport {
            delay: Duration::from_millis(100),
        };
        // The per-step 1s timeout must win over the 10ms configured default.
        let skill = Skill::new("skill", "Skill with per-step timeout").with_step(SkillStep {
            timeout_secs: Some(1),
            max_retries: Some(0),
            ..make_step("step", "tool", serde_json::json!({}))
        });
        let config = ExecutionConfig::new()
            .with_timeout(TimeoutConfig::new().with_step_timeout(Duration::from_millis(10)));

        let outcome = skill
            .execute_with_config(&slow, &HashMap::new(), &config)
            .await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_skill_continue_on_error() {
        let mock = MockTransport::new().with_response(
            "step2",
            ToolResult::success(serde_json::json!({"ok": true})),
        );
        let skill = Skill::new("skill", "Skill with continue_on_error")
            .with_step(SkillStep {
                continue_on_error: true,
                max_retries: Some(0),
                ..make_step("step1", "missing_tool", serde_json::json!({}))
            })
            .with_step(make_step("step2", "step2", serde_json::json!({})));
        let config = ExecutionConfig::new().with_retry(RetryConfig::no_retries());

        let outcome = skill
            .execute_with_config(&mock, &HashMap::new(), &config)
            .await
            .unwrap();

        // A skipped failure does not mark the whole skill as failed.
        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 2);

        let (_, failed) = &outcome.step_results[0];
        assert!(!failed.is_success());

        let (_, succeeded) = &outcome.step_results[1];
        assert!(succeeded.is_success());
    }

    #[test]
    fn test_skill_step_serialization() {
        let original = SkillStep {
            timeout_secs: Some(30),
            max_retries: Some(2),
            ..make_step("test", "tool", serde_json::json!({}))
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SkillStep = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.timeout_secs, Some(30));
        assert_eq!(decoded.max_retries, Some(2));
    }

    #[test]
    fn test_skill_step_default_optional_fields() {
        let raw = r#"{"name": "test", "tool": "tool", "arguments": {}}"#;

        let parsed: SkillStep = serde_json::from_str(raw).unwrap();

        assert!(!parsed.continue_on_error);
        assert_eq!(parsed.timeout_secs, None);
        assert_eq!(parsed.max_retries, None);
    }
}