//! Skill definitions and the skill execution engine (`ditto_os/skills/mod.rs`).

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;

8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Skill {
10    pub id: String,
11    pub name: String,
12    pub description: String,
13    pub category: SkillCategory,
14    pub parameters: Vec<SkillParameter>,
15    pub steps: Vec<SkillStep>,
16    pub tags: Vec<String>,
17    pub author: String,
18    pub version: String,
19    pub created_at: chrono::DateTime<chrono::Utc>,
20    pub public: bool,
21    pub rating: f32,
22    pub usage_count: u64,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
26pub enum SkillCategory {
27    Navigation,
28    FormFilling,
29    DataExtraction,
30    Testing,
31    Automation,
32    Security,
33    Ecommerce,
34    SocialMedia,
35    General,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SkillParameter {
40    pub name: String,
41    pub description: String,
42    pub parameter_type: ParameterType,
43    pub required: bool,
44    pub default_value: Option<serde_json::Value>,
45    pub validation: Option<ParameterValidation>,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub enum ParameterType {
50    String,
51    Number,
52    Boolean,
53    Array,
54    Object,
55    Url,
56    Selector,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ParameterValidation {
61    pub min_length: Option<usize>,
62    pub max_length: Option<usize>,
63    pub pattern: Option<String>,
64    pub allowed_values: Option<Vec<serde_json::Value>>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SkillStep {
69    pub id: String,
70    pub name: String,
71    pub action: StepAction,
72    pub parameters: HashMap<String, serde_json::Value>,
73    pub timeout_ms: u32,
74    pub retry_count: u32,
75    pub on_failure: StepFailureAction,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub enum StepAction {
80    Navigate,
81    Click,
82    Type,
83    Wait,
84    Screenshot,
85    ExecuteScript,
86    ExtractText,
87    ExtractAttribute,
88    Scroll,
89    WaitForElement,
90    Condition,
91    Loop,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub enum StepFailureAction {
96    Stop,
97    Continue,
98    Retry,
99    Skip,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct SkillExecution {
104    pub id: String,
105    pub skill_id: String,
106    pub agent_id: String,
107    pub session_id: String,
108    pub parameters: HashMap<String, serde_json::Value>,
109    pub status: ExecutionStatus,
110    pub started_at: chrono::DateTime<chrono::Utc>,
111    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
112    pub current_step: Option<String>,
113    pub step_results: HashMap<String, StepResult>,
114    pub error: Option<String>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
118pub enum ExecutionStatus {
119    Pending,
120    Running,
121    Completed,
122    Failed,
123    Cancelled,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct StepResult {
128    pub step_id: String,
129    pub success: bool,
130    pub data: serde_json::Value,
131    pub error: Option<String>,
132    pub execution_time_ms: u64,
133    pub screenshot: Option<String>,
134}
135
136pub struct SkillEngine {
137    skills: Arc<RwLock<HashMap<String, Skill>>>,
138    executions: Arc<RwLock<HashMap<String, SkillExecution>>>,
139}
140
141impl Default for SkillEngine {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl SkillEngine {
148    pub fn new() -> Self {
149        Self {
150            skills: Arc::new(RwLock::new(HashMap::new())),
151            executions: Arc::new(RwLock::new(HashMap::new())),
152        }
153    }
154
155    pub async fn register_skill(&self, skill: Skill) -> Result<(), anyhow::Error> {
156        let mut skills = self.skills.write().await;
157
158        if skills.contains_key(&skill.id) {
159            return Err(anyhow::anyhow!("Skill with ID {} already exists", skill.id));
160        }
161
162        skills.insert(skill.id.clone(), skill.clone());
163        info!("Registered skill: {} (ID: {})", skill.name, skill.id);
164
165        Ok(())
166    }
167
168    pub async fn get_skill(&self, skill_id: &str) -> Option<Skill> {
169        let skills = self.skills.read().await;
170        skills.get(skill_id).cloned()
171    }
172
173    pub async fn list_skills(
174        &self,
175        category: Option<SkillCategory>,
176        tags: Option<Vec<String>>,
177    ) -> Vec<Skill> {
178        let skills = self.skills.read().await;
179
180        skills
181            .values()
182            .filter(|skill| {
183                let category_match = category.as_ref().is_none_or(|c| skill.category == *c);
184                let tags_match = tags
185                    .as_ref()
186                    .is_none_or(|t| t.iter().all(|tag| skill.tags.contains(tag)));
187                category_match && tags_match
188            })
189            .cloned()
190            .collect()
191    }
192
193    pub async fn search_skills(&self, query: &str) -> Vec<Skill> {
194        let skills = self.skills.read().await;
195        let query_lower = query.to_lowercase();
196
197        skills
198            .values()
199            .filter(|skill| {
200                skill.name.to_lowercase().contains(&query_lower)
201                    || skill.description.to_lowercase().contains(&query_lower)
202                    || skill
203                        .tags
204                        .iter()
205                        .any(|tag| tag.to_lowercase().contains(&query_lower))
206            })
207            .cloned()
208            .collect()
209    }
210
211    pub async fn execute_skill(
212        &self,
213        skill_id: &str,
214        agent_id: String,
215        session_id: String,
216        parameters: HashMap<String, serde_json::Value>,
217    ) -> Result<String, anyhow::Error> {
218        let skill = {
219            let skills = self.skills.read().await;
220            skills
221                .get(skill_id)
222                .cloned()
223                .ok_or_else(|| anyhow::anyhow!("Skill not found: {}", skill_id))?
224        };
225
226        // Validate parameters
227        self.validate_skill_parameters(&skill, &parameters).await?;
228
229        let execution_id = Uuid::new_v4().to_string();
230
231        let execution = SkillExecution {
232            id: execution_id.clone(),
233            skill_id: skill_id.to_string(),
234            agent_id,
235            session_id,
236            parameters,
237            status: ExecutionStatus::Pending,
238            started_at: chrono::Utc::now(),
239            completed_at: None,
240            current_step: None,
241            step_results: HashMap::new(),
242            error: None,
243        };
244
245        {
246            let mut executions = self.executions.write().await;
247            executions.insert(execution_id.clone(), execution);
248        }
249
250        // Start execution in background
251        let skill_clone = skill.clone();
252        let executions_clone = Arc::clone(&self.executions);
253        let execution_id_for_spawn = execution_id.clone();
254
255        tokio::spawn(async move {
256            if let Err(e) = Self::run_skill_execution(
257                skill_clone,
258                execution_id_for_spawn.clone(),
259                executions_clone,
260            )
261            .await
262            {
263                error!("Skill execution {} failed: {}", execution_id_for_spawn, e);
264            }
265        });
266
267        info!("Started skill execution: {}", execution_id);
268        Ok(execution_id)
269    }
270
271    pub async fn get_execution(&self, execution_id: &str) -> Option<SkillExecution> {
272        let executions = self.executions.read().await;
273        executions.get(execution_id).cloned()
274    }
275
276    pub async fn cancel_execution(&self, execution_id: &str) -> Result<(), anyhow::Error> {
277        let mut executions = self.executions.write().await;
278
279        if let Some(execution) = executions.get_mut(execution_id) {
280            match execution.status {
281                ExecutionStatus::Pending | ExecutionStatus::Running => {
282                    execution.status = ExecutionStatus::Cancelled;
283                    execution.completed_at = Some(chrono::Utc::now());
284                    info!("Cancelled skill execution: {}", execution_id);
285                    Ok(())
286                }
287                _ => Err(anyhow::anyhow!(
288                    "Cannot cancel execution in status: {:?}",
289                    execution.status
290                )),
291            }
292        } else {
293            Err(anyhow::anyhow!("Execution not found: {}", execution_id))
294        }
295    }
296
297    pub async fn list_executions(&self, agent_id: Option<&str>) -> Vec<SkillExecution> {
298        let executions = self.executions.read().await;
299
300        executions
301            .values()
302            .filter(|execution| agent_id.is_none_or(|id| execution.agent_id == id))
303            .cloned()
304            .collect()
305    }
306
307    async fn validate_skill_parameters(
308        &self,
309        skill: &Skill,
310        parameters: &HashMap<String, serde_json::Value>,
311    ) -> Result<(), anyhow::Error> {
312        for param in &skill.parameters {
313            let value = parameters.get(&param.name);
314
315            if param.required && value.is_none() {
316                return Err(anyhow::anyhow!(
317                    "Required parameter '{}' is missing",
318                    param.name
319                ));
320            }
321
322            if let Some(value) = value {
323                self.validate_parameter_value(param, value)?;
324            }
325        }
326
327        Ok(())
328    }
329
330    fn validate_parameter_value(
331        &self,
332        param: &SkillParameter,
333        value: &serde_json::Value,
334    ) -> Result<(), anyhow::Error> {
335        // Type validation
336        match param.parameter_type {
337            ParameterType::String => {
338                if !value.is_string() {
339                    return Err(anyhow::anyhow!(
340                        "Parameter '{}' must be a string",
341                        param.name
342                    ));
343                }
344            }
345            ParameterType::Number => {
346                if !value.is_number() {
347                    return Err(anyhow::anyhow!(
348                        "Parameter '{}' must be a number",
349                        param.name
350                    ));
351                }
352            }
353            ParameterType::Boolean => {
354                if !value.is_boolean() {
355                    return Err(anyhow::anyhow!(
356                        "Parameter '{}' must be a boolean",
357                        param.name
358                    ));
359                }
360            }
361            ParameterType::Array => {
362                if !value.is_array() {
363                    return Err(anyhow::anyhow!(
364                        "Parameter '{}' must be an array",
365                        param.name
366                    ));
367                }
368            }
369            ParameterType::Object => {
370                if !value.is_object() {
371                    return Err(anyhow::anyhow!(
372                        "Parameter '{}' must be an object",
373                        param.name
374                    ));
375                }
376            }
377            _ => {} // Skip validation for URL and Selector for now
378        }
379
380        // Additional validation
381        if let Some(validation) = &param.validation {
382            if let Some(string_value) = value.as_str() {
383                if let Some(min_length) = validation.min_length {
384                    if string_value.len() < min_length {
385                        return Err(anyhow::anyhow!(
386                            "Parameter '{}' must be at least {} characters",
387                            param.name,
388                            min_length
389                        ));
390                    }
391                }
392
393                if let Some(max_length) = validation.max_length {
394                    if string_value.len() > max_length {
395                        return Err(anyhow::anyhow!(
396                            "Parameter '{}' must be at most {} characters",
397                            param.name,
398                            max_length
399                        ));
400                    }
401                }
402
403                if let Some(pattern) = &validation.pattern {
404                    // Simple regex validation (in production, use regex crate)
405                    if !string_value.contains(pattern) {
406                        return Err(anyhow::anyhow!(
407                            "Parameter '{}' does not match required pattern",
408                            param.name
409                        ));
410                    }
411                }
412            }
413        }
414
415        Ok(())
416    }
417
418    async fn run_skill_execution(
419        skill: Skill,
420        execution_id: String,
421        executions: Arc<RwLock<HashMap<String, SkillExecution>>>,
422    ) -> Result<(), anyhow::Error> {
423        // Update status to running
424        {
425            let mut execs = executions.write().await;
426            if let Some(execution) = execs.get_mut(&execution_id) {
427                execution.status = ExecutionStatus::Running;
428            }
429        }
430
431        info!("Running skill execution: {}", execution_id);
432
433        // Execute each step
434        for step in &skill.steps {
435            // Update current step
436            {
437                let mut execs = executions.write().await;
438                if let Some(execution) = execs.get_mut(&execution_id) {
439                    execution.current_step = Some(step.id.clone());
440                }
441            }
442
443            debug!("Executing step: {}", step.name);
444
445            let step_result = Self::execute_skill_step(step).await;
446
447            // Store step result
448            {
449                let mut execs = executions.write().await;
450                if let Some(execution) = execs.get_mut(&execution_id) {
451                    execution
452                        .step_results
453                        .insert(step.id.clone(), step_result.clone());
454
455                    if !step_result.success {
456                        match step.on_failure {
457                            StepFailureAction::Stop => {
458                                execution.status = ExecutionStatus::Failed;
459                                execution.error = step_result.error;
460                                execution.completed_at = Some(chrono::Utc::now());
461                                return Err(anyhow::anyhow!(
462                                    "Skill execution failed at step: {}",
463                                    step.name
464                                ));
465                            }
466                            StepFailureAction::Skip => continue,
467                            StepFailureAction::Retry => {
468                                // Implement retry logic
469                                continue;
470                            }
471                            StepFailureAction::Continue => continue,
472                        }
473                    }
474                }
475            }
476        }
477
478        // Mark as completed
479        {
480            let mut execs = executions.write().await;
481            if let Some(execution) = execs.get_mut(&execution_id) {
482                execution.status = ExecutionStatus::Completed;
483                execution.completed_at = Some(chrono::Utc::now());
484                execution.current_step = None;
485            }
486        }
487
488        info!("Skill execution completed: {}", execution_id);
489        Ok(())
490    }
491
492    async fn execute_skill_step(step: &SkillStep) -> StepResult {
493        let start_time = std::time::Instant::now();
494
495        debug!("Executing step action: {:?}", step.action);
496
497        // Simulate step execution
498        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
499
500        let execution_time = start_time.elapsed().as_millis() as u64;
501
502        // In a real implementation, this would execute actual browser commands
503        let (success, data, error) = match &step.action {
504            StepAction::Navigate => {
505                if let Some(url) = step.parameters.get("url") {
506                    (true, url.clone(), None)
507                } else {
508                    (
509                        false,
510                        serde_json::Value::Null,
511                        Some("Missing URL parameter".to_string()),
512                    )
513                }
514            }
515            StepAction::Click => {
516                if let Some(selector) = step.parameters.get("selector") {
517                    (true, serde_json::json!({"clicked": selector}), None)
518                } else {
519                    (
520                        false,
521                        serde_json::Value::Null,
522                        Some("Missing selector parameter".to_string()),
523                    )
524                }
525            }
526            StepAction::Type => {
527                if let Some(text) = step.parameters.get("text") {
528                    (true, serde_json::json!({"typed": text}), None)
529                } else {
530                    (
531                        false,
532                        serde_json::Value::Null,
533                        Some("Missing text parameter".to_string()),
534                    )
535                }
536            }
537            StepAction::Screenshot => (
538                true,
539                serde_json::json!({"screenshot": "base64_image_data"}),
540                None,
541            ),
542            _ => (true, serde_json::json!({"result": "success"}), None),
543        };
544
545        StepResult {
546            step_id: step.id.clone(),
547            success,
548            data,
549            error,
550            execution_time_ms: execution_time,
551            screenshot: None,
552        }
553    }
554
555    pub async fn get_skill_stats(&self) -> SkillStats {
556        let skills = self.skills.read().await;
557        let executions = self.executions.read().await;
558
559        let total_skills = skills.len();
560        let public_skills = skills.values().filter(|s| s.public).count();
561
562        let category_counts = {
563            let mut counts = HashMap::new();
564            for skill in skills.values() {
565                *counts.entry(format!("{:?}", skill.category)).or_insert(0) += 1;
566            }
567            counts
568        };
569
570        let total_executions = executions.len();
571        let running_executions = executions
572            .values()
573            .filter(|e| e.status == ExecutionStatus::Running)
574            .count();
575        let completed_executions = executions
576            .values()
577            .filter(|e| e.status == ExecutionStatus::Completed)
578            .count();
579        let failed_executions = executions
580            .values()
581            .filter(|e| e.status == ExecutionStatus::Failed)
582            .count();
583
584        SkillStats {
585            total_skills,
586            public_skills,
587            category_counts,
588            total_executions,
589            running_executions,
590            completed_executions,
591            failed_executions,
592        }
593    }
594
595    // Initialize with some default skills
596    pub async fn init_default_skills(&self) -> Result<(), anyhow::Error> {
597        let default_skills = vec![
598            Skill {
599                id: "navigate-to-url".to_string(),
600                name: "Navigate to URL".to_string(),
601                description: "Navigate to a specific URL".to_string(),
602                category: SkillCategory::Navigation,
603                parameters: vec![SkillParameter {
604                    name: "url".to_string(),
605                    description: "URL to navigate to".to_string(),
606                    parameter_type: ParameterType::Url,
607                    required: true,
608                    default_value: None,
609                    validation: None,
610                }],
611                steps: vec![SkillStep {
612                    id: "step1".to_string(),
613                    name: "Navigate to URL".to_string(),
614                    action: StepAction::Navigate,
615                    parameters: HashMap::from([(
616                        "url".to_string(),
617                        serde_json::Value::String("${url}".to_string()),
618                    )]),
619                    timeout_ms: 10000,
620                    retry_count: 3,
621                    on_failure: StepFailureAction::Stop,
622                }],
623                tags: vec!["navigation".to_string(), "basic".to_string()],
624                author: "Ditto Team".to_string(),
625                version: "1.0.0".to_string(),
626                created_at: chrono::Utc::now(),
627                public: true,
628                rating: 5.0,
629                usage_count: 0,
630            },
631            Skill {
632                id: "take-screenshot".to_string(),
633                name: "Take Screenshot".to_string(),
634                description: "Take a screenshot of the current page".to_string(),
635                category: SkillCategory::Testing,
636                parameters: vec![SkillParameter {
637                    name: "filename".to_string(),
638                    description: "Filename for the screenshot".to_string(),
639                    parameter_type: ParameterType::String,
640                    required: false,
641                    default_value: Some(serde_json::Value::String("screenshot.png".to_string())),
642                    validation: None,
643                }],
644                steps: vec![SkillStep {
645                    id: "step1".to_string(),
646                    name: "Take screenshot".to_string(),
647                    action: StepAction::Screenshot,
648                    parameters: HashMap::from([(
649                        "filename".to_string(),
650                        serde_json::Value::String("${filename}".to_string()),
651                    )]),
652                    timeout_ms: 5000,
653                    retry_count: 2,
654                    on_failure: StepFailureAction::Stop,
655                }],
656                tags: vec!["screenshot".to_string(), "testing".to_string()],
657                author: "Ditto Team".to_string(),
658                version: "1.0.0".to_string(),
659                created_at: chrono::Utc::now(),
660                public: true,
661                rating: 4.5,
662                usage_count: 0,
663            },
664        ];
665
666        for skill in default_skills {
667            self.register_skill(skill).await?;
668        }
669
670        info!("Initialized default skills");
671        Ok(())
672    }
673}
674
675#[derive(Debug, Clone, Serialize, Deserialize)]
676pub struct SkillStats {
677    pub total_skills: usize,
678    pub public_skills: usize,
679    pub category_counts: HashMap<String, usize>,
680    pub total_executions: usize,
681    pub running_executions: usize,
682    pub completed_executions: usize,
683    pub failed_executions: usize,
684}