use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::RwLock;
use crate::{DriverCategory, DriverSignal};
/// Declarative description of a single parameter accepted by a driver.
///
/// The struct round-trips through serde. `param_type` travels under the key
/// `"type"`, and every field carrying `#[serde(default)]` may be left out of the
/// input, in which case it takes its `Default` value (`false` / `None`).
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct DriverParameter {
    /// Key under which the value is looked up in a call's parameter map.
    pub name: String,
    /// Declared type name. `Driver::validate` in this file recognises names such
    /// as `"string"`, `"integer"`, `"boolean"`, `"array"` and `"object"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable explanation of the parameter.
    pub description: String,
    /// Whether validation rejects a call that omits this parameter.
    #[serde(default)]
    pub required: bool,
    /// Optional default value. Nothing in this file applies it automatically.
    #[serde(default)]
    pub default: Option<Value>,
    /// Optional sample value.
    #[serde(default)]
    pub example: Option<Value>,
    /// Optional whitelist that string values are checked against during validation.
    #[serde(default)]
    pub enum_values: Option<Vec<String>>,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct DriverMetadata {
pub name: String,
pub description: String,
pub usage_hint: String,
pub parameters: Vec<DriverParameter>,
pub example_call: serde_json::Value,
pub example_output: String,
pub category: DriverCategory,
}
/// Optional metadata handed to a driver alongside its parameters.
///
/// Every field starts out empty (`None` / an empty map) when the value is built
/// through `Default`.
#[derive(Clone, Debug, Default)]
pub struct DriverContext {
    /// Identifier of the task this invocation belongs to, if one was set.
    pub task_id: Option<String>,
    /// Index associated with the driver, if one was set.
    pub driver_index: Option<usize>,
    /// Name associated with the driver, if one was set.
    pub driver_name: Option<String>,
    /// Free-form key/value data attached to the context.
    pub extra: HashMap<String, Value>,
    /// Optional `'static` reference to a lock-guarded, two-level signal map.
    ///
    /// NOTE(review): nothing in this file reads it; the outer `String` key is
    /// presumably a task id and the inner `usize` a driver index — confirm
    /// against the code that owns the bus.
    pub signal_bus: Option<&'static RwLock<HashMap<String, HashMap<usize, DriverSignal>>>>,
}
impl DriverContext {
    /// Creates a context with every field empty.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a context in which only `task_id` is populated.
    pub fn with_task_id(task_id: impl Into<String>) -> Self {
        Self::default().with_task_id_builder(task_id)
    }

    /// Borrows the task id, if one is set.
    pub fn task_id(&self) -> Option<&str> {
        self.task_id.as_deref()
    }

    /// Returns the driver index, if one is set.
    pub fn driver_index(&self) -> Option<usize> {
        self.driver_index
    }

    /// Borrows the driver name, if one is set.
    pub fn driver_name(&self) -> Option<&str> {
        self.driver_name.as_deref()
    }

    /// Looks up `key` in the `extra` map.
    pub fn get_extra(&self, key: &str) -> Option<&Value> {
        self.extra.get(key)
    }

    /// Overwrites the task id and returns `self` so calls can be chained.
    pub fn set_task_id(&mut self, task_id: impl Into<String>) -> &mut Self {
        let id: String = task_id.into();
        self.task_id = Some(id);
        self
    }

    /// Overwrites the driver index and returns `self` so calls can be chained.
    pub fn set_driver_index(&mut self, driver_index: usize) -> &mut Self {
        self.driver_index = Some(driver_index);
        self
    }

    /// Overwrites the driver name and returns `self` so calls can be chained.
    pub fn set_driver_name(&mut self, driver_name: impl Into<String>) -> &mut Self {
        let name: String = driver_name.into();
        self.driver_name = Some(name);
        self
    }

    /// Stores `value` under `key` in the `extra` map, replacing any previous
    /// entry, and returns `self` so calls can be chained.
    pub fn insert_extra(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
        self.extra.insert(key.into(), value);
        self
    }

    /// Removes `key` from the `extra` map and hands back its value, if present.
    pub fn remove_extra(&mut self, key: &str) -> Option<Value> {
        self.extra.remove(key)
    }

    /// Reports whether the `extra` map has an entry for `key`.
    pub fn has_extra(&self, key: &str) -> bool {
        self.extra.contains_key(key)
    }

    /// By-value counterpart of [`set_task_id`](Self::set_task_id).
    pub fn with_task_id_builder(mut self, task_id: impl Into<String>) -> Self {
        self.set_task_id(task_id);
        self
    }

    /// By-value counterpart of [`set_driver_index`](Self::set_driver_index).
    pub fn with_driver_index(mut self, driver_index: usize) -> Self {
        self.set_driver_index(driver_index);
        self
    }

    /// By-value counterpart of [`set_driver_name`](Self::set_driver_name).
    pub fn with_driver_name(mut self, driver_name: impl Into<String>) -> Self {
        self.set_driver_name(driver_name);
        self
    }

    /// By-value counterpart of [`insert_extra`](Self::insert_extra).
    pub fn with_extra(mut self, key: impl Into<String>, value: Value) -> Self {
        self.insert_extra(key, value);
        self
    }

    /// Merges every entry of `extra` into the context; entries with a key that
    /// already exists are overwritten.
    pub fn with_extra_map(mut self, extra: HashMap<String, Value>) -> Self {
        self.extra.extend(extra);
        self
    }

    /// Ends a builder chain. Returns the context unchanged.
    pub fn build(self) -> Self {
        self
    }
}
/// Observer interface for events reported during a driver run.
///
/// Implementors must be `Send + Sync + Debug`. Only
/// [`on_log`](DriverCallback::on_log) and
/// [`on_progress`](DriverCallback::on_progress) have to be implemented; every
/// other hook has a provided body that does nothing.
///
/// This file contains no call sites, so when each hook fires is decided by the
/// caller that owns the callback.
pub trait DriverCallback: Send + Sync + Debug {
    /// Receives a log message together with the optional task id and driver
    /// index it relates to.
    fn on_log(&self, task_id: Option<String>, driver_index: Option<usize>, message: Option<String>);

    /// Receives driver output. The provided implementation ignores it.
    fn on_output(
        &self,
        task_id: Option<String>,
        driver_index: Option<usize>,
        driver_name: Option<String>,
        output: Option<String>,
    ) {
        // Explicit discard: keeps the descriptive parameter names for
        // implementors without triggering `unused_variables`.
        let _ = (task_id, driver_index, driver_name, output);
    }

    /// Receives a progress update.
    ///
    /// NOTE(review): the unit of `progress` (percentage, step count, ...) is not
    /// established anywhere in this file — confirm with the callers.
    fn on_progress(
        &self,
        task_id: Option<String>,
        driver_index: Option<usize>,
        progress: Option<u32>,
        message: Option<String>,
    );

    /// Start notification. The provided implementation ignores it.
    fn on_start(
        &self,
        task_id: Option<String>,
        driver_index: Option<usize>,
        driver_name: Option<String>,
    ) {
        // Explicit discard, see `on_output`.
        let _ = (task_id, driver_index, driver_name);
    }

    /// Completion notification carrying the optional output. The provided
    /// implementation ignores it.
    fn on_complete(
        &self,
        task_id: Option<String>,
        driver_index: Option<usize>,
        driver_name: Option<String>,
        output: Option<String>,
    ) {
        // Explicit discard, see `on_output`.
        let _ = (task_id, driver_index, driver_name, output);
    }

    /// Error notification carrying the optional error text. The provided
    /// implementation ignores it.
    fn on_error(
        &self,
        task_id: Option<String>,
        driver_index: Option<usize>,
        driver_name: Option<String>,
        error: Option<String>,
    ) {
        // Explicit discard, see `on_output`.
        let _ = (task_id, driver_index, driver_name, error);
    }
}
/// An executable unit that describes itself (name, parameters, examples) and
/// can be run asynchronously with a JSON parameter map.
///
/// Implementors must supply [`name`](Driver::name),
/// [`description`](Driver::description) and [`execute`](Driver::execute); all
/// other methods have provided implementations.
#[async_trait::async_trait]
pub trait Driver: Send + Sync + Debug {
    /// Name of the driver.
    fn name(&self) -> &str;

    /// Human-readable description of the driver.
    fn description(&self) -> &str;

    /// Short guidance on how to use the driver. Defaults to a placeholder text.
    fn usage_hint(&self) -> &str {
        "No usage hint provided"
    }

    /// Definitions of the parameters the driver accepts. Defaults to none.
    fn parameters(&self) -> Vec<DriverParameter> {
        Vec::new()
    }

    /// Sample invocation. Defaults to an empty JSON object.
    fn example_call(&self) -> serde_json::Value {
        serde_json::json!({})
    }

    /// Sample result text. Defaults to an empty string.
    fn example_output(&self) -> String {
        String::new()
    }

    /// Category of the driver. Defaults to [`DriverCategory::Basic`].
    fn category(&self) -> DriverCategory {
        DriverCategory::Basic
    }

    /// Runs the driver.
    ///
    /// * `parameters` - argument values keyed by parameter name.
    /// * `callback` - optional observer the implementation may report to.
    /// * `context` - optional invocation metadata.
    ///
    /// Returns the driver's textual result, or an error if the run fails.
    async fn execute(
        &self,
        parameters: &HashMap<String, Value>,
        callback: Option<&dyn DriverCallback>,
        context: Option<&DriverContext>,
    ) -> Result<String>;

    /// Checks `parameters` against the definitions from
    /// [`parameters`](Driver::parameters).
    ///
    /// For each definition, in order:
    /// 1. a missing value is an error when the definition is `required`, and is
    ///    skipped otherwise;
    /// 2. a present value must match the declared `param_type` (`"string"`,
    ///    `"integer"`, `"number"`, `"boolean"`, `"array"`, `"object"`); unknown
    ///    type names are accepted as-is;
    /// 3. when `enum_values` is set and the value is a string, the string must
    ///    be one of the listed values. Non-string values skip this check.
    ///
    /// Keys in `parameters` that match no definition are ignored.
    ///
    /// # Errors
    ///
    /// Returns the first violation found; later definitions are not inspected.
    fn validate(&self, parameters: &HashMap<String, Value>) -> Result<()> {
        for def in self.parameters() {
            // A single lookup decides both "missing" and "present".
            let value = match parameters.get(&def.name) {
                Some(value) => value,
                None if def.required => {
                    anyhow::bail!("Required parameter '{}' is missing", def.name);
                }
                None => continue,
            };

            let type_matches = match def.param_type.as_str() {
                "string" => value.is_string(),
                "integer" => value.is_i64() || value.is_u64(),
                // Previously unchecked: any JSON value passed for a parameter
                // declared as "number".
                "number" => value.is_number(),
                "boolean" => value.is_boolean(),
                "array" => value.is_array(),
                "object" => value.is_object(),
                // Unknown type names are deliberately not rejected.
                _ => true,
            };
            if !type_matches {
                anyhow::bail!(
                    "Parameter '{}' expects type '{}' but got {:?}",
                    def.name,
                    def.param_type,
                    value
                );
            }

            if let (Some(allowed), Some(given)) = (def.enum_values.as_ref(), value.as_str()) {
                // Compare as `&str` so no `String` is allocated per check.
                if !allowed.iter().any(|candidate| candidate.as_str() == given) {
                    anyhow::bail!(
                        "Parameter '{}' value '{}' is not in allowed values: {:?}",
                        def.name,
                        given,
                        allowed
                    );
                }
            }
        }
        Ok(())
    }

    /// Collects the descriptive methods of this driver into a
    /// [`DriverMetadata`] value.
    fn get_metadata(&self) -> DriverMetadata {
        DriverMetadata {
            name: self.name().to_string(),
            description: self.description().to_string(),
            usage_hint: self.usage_hint().to_string(),
            parameters: self.parameters(),
            example_call: self.example_call(),
            example_output: self.example_output(),
            category: self.category(),
        }
    }
}
/// A request to invoke something by name with a set of JSON arguments.
///
/// Round-trips through serde. When `parameters` is absent from the input it
/// deserializes to an empty map.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DriverCall {
    /// Name of the action being requested.
    pub action: String,
    /// Argument values keyed by parameter name; empty when omitted.
    #[serde(default)]
    pub parameters: HashMap<String, Value>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Driver with a configurable name/description that echoes its `input`
    /// parameter.
    #[derive(Debug)]
    struct TestDriver {
        name: String,
        description: String,
    }

    #[async_trait::async_trait]
    impl Driver for TestDriver {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        async fn execute(
            &self,
            params: &HashMap<String, Value>,
            _callback: Option<&dyn DriverCallback>,
            _context: Option<&DriverContext>,
        ) -> Result<String> {
            let input = params
                .get("input")
                .and_then(Value::as_str)
                .unwrap_or("no input");
            Ok(format!("Executed {} with: {}", self.name, input))
        }

        fn parameters(&self) -> Vec<DriverParameter> {
            vec![
                DriverParameter {
                    name: "input".to_string(),
                    param_type: "string".to_string(),
                    description: "Input string".to_string(),
                    required: true,
                    default: None,
                    example: Some(json!("test")),
                    enum_values: None,
                },
                DriverParameter {
                    name: "count".to_string(),
                    param_type: "integer".to_string(),
                    description: "Count value".to_string(),
                    required: false,
                    default: Some(json!(1)),
                    example: Some(json!(5)),
                    enum_values: None,
                },
            ]
        }

        fn usage_hint(&self) -> &str {
            "Use this Driver to test functionality"
        }

        fn category(&self) -> DriverCategory {
            DriverCategory::Basic
        }
    }

    /// Driver whose parameter definitions exercise the required, type and enum
    /// checks of `Driver::validate`.
    #[derive(Debug)]
    struct ValidatingDriver;

    #[async_trait::async_trait]
    impl Driver for ValidatingDriver {
        fn name(&self) -> &str {
            "validator"
        }

        fn description(&self) -> &str {
            "Validates parameters"
        }

        async fn execute(
            &self,
            params: &HashMap<String, Value>,
            _callback: Option<&dyn DriverCallback>,
            _context: Option<&DriverContext>,
        ) -> Result<String> {
            Ok(format!("Validated: {:?}", params))
        }

        fn parameters(&self) -> Vec<DriverParameter> {
            vec![
                DriverParameter {
                    name: "color".to_string(),
                    param_type: "string".to_string(),
                    description: "Color name".to_string(),
                    required: true,
                    default: None,
                    example: Some(json!("red")),
                    enum_values: Some(vec![
                        "red".to_string(),
                        "green".to_string(),
                        "blue".to_string(),
                    ]),
                },
                DriverParameter {
                    name: "value".to_string(),
                    param_type: "integer".to_string(),
                    description: "Numeric value".to_string(),
                    required: false,
                    default: Some(json!(0)),
                    example: Some(json!(42)),
                    enum_values: None,
                },
            ]
        }
    }

    #[tokio::test]
    async fn test_driver_metadata_creation() {
        let driver = TestDriver {
            name: "test_driver".to_string(),
            description: "A test driver".to_string(),
        };
        let metadata = driver.get_metadata();

        // Expected values must match exactly what the driver was built with.
        assert_eq!(metadata.name, "test_driver");
        assert_eq!(metadata.description, "A test driver");
        assert_eq!(metadata.usage_hint, "Use this Driver to test functionality");
        assert_eq!(metadata.category, DriverCategory::Basic);
        assert_eq!(metadata.parameters.len(), 2);
        assert_eq!(metadata.parameters[0].name, "input");
        assert!(metadata.parameters[0].required);
        assert_eq!(metadata.parameters[1].name, "count");
        assert!(!metadata.parameters[1].required);
    }

    #[tokio::test]
    async fn test_driver_execution_with_parameters() {
        let driver = TestDriver {
            name: "echo_Driver".to_string(),
            description: "Echoes input".to_string(),
        };
        let mut params = HashMap::new();
        params.insert("input".to_string(), json!("Hello, World!"));

        let result = driver.execute(&params, None, None).await.unwrap();
        assert_eq!(result, "Executed echo_Driver with: Hello, World!");
    }

    #[tokio::test]
    async fn test_driver_validation() {
        let driver = ValidatingDriver;

        // All parameters present and well-typed.
        let mut valid_params = HashMap::new();
        valid_params.insert("color".to_string(), json!("red"));
        valid_params.insert("value".to_string(), json!(42));
        assert!(driver.validate(&valid_params).is_ok());

        // Required parameter absent.
        let mut missing_required = HashMap::new();
        missing_required.insert("value".to_string(), json!(42));
        let err = driver
            .validate(&missing_required)
            .expect_err("missing required parameter must be rejected");
        assert!(err.to_string().contains("Required parameter 'color'"));

        // String value outside the enum whitelist.
        let mut invalid_enum = HashMap::new();
        invalid_enum.insert("color".to_string(), json!("yellow"));
        let err = driver
            .validate(&invalid_enum)
            .expect_err("value outside the enum must be rejected");
        assert!(err.to_string().contains("not in allowed values"));

        // Value of the wrong JSON type.
        let mut wrong_type = HashMap::new();
        wrong_type.insert("color".to_string(), json!("red"));
        wrong_type.insert("value".to_string(), json!("not an integer"));
        let err = driver
            .validate(&wrong_type)
            .expect_err("string passed for an integer parameter must be rejected");
        assert!(err.to_string().contains("expects type 'integer'"));
    }

    #[test]
    fn test_driver_call_deserialization() {
        let json_data = json!({
            "action": "read_file",
            "parameters": {
                "path": "./config.json",
                "encoding": "utf-8"
            }
        });
        let call: DriverCall = serde_json::from_value(json_data).unwrap();

        assert_eq!(call.action, "read_file");
        assert_eq!(call.parameters.len(), 2);
        assert_eq!(
            call.parameters.get("path").and_then(Value::as_str),
            Some("./config.json")
        );
        assert_eq!(
            call.parameters.get("encoding").and_then(Value::as_str),
            Some("utf-8")
        );
    }

    #[test]
    fn test_driver_call_without_parameters() {
        let json_data = json!({
            "action": "list_Drivers"
        });
        let call: DriverCall = serde_json::from_value(json_data).unwrap();

        assert_eq!(call.action, "list_Drivers");
        assert!(call.parameters.is_empty());
    }

    #[test]
    fn test_driver_parameter_serialization() {
        let param = DriverParameter {
            name: "timeout".to_string(),
            param_type: "integer".to_string(),
            description: "Timeout in seconds".to_string(),
            required: true,
            default: Some(json!(30)),
            example: Some(json!(60)),
            enum_values: None,
        };

        let serialized = serde_json::to_string(&param).unwrap();
        let deserialized: DriverParameter = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.name, param.name);
        assert_eq!(deserialized.param_type, param.param_type);
        assert_eq!(deserialized.required, param.required);
        assert_eq!(deserialized.default, param.default);
    }

    #[test]
    fn test_driver_parameter_with_enum_values() {
        let param = DriverParameter {
            name: "mode".to_string(),
            param_type: "string".to_string(),
            description: "Operation mode".to_string(),
            required: true,
            default: None,
            example: Some(json!("fast")),
            enum_values: Some(vec![
                "fast".to_string(),
                "slow".to_string(),
                "balanced".to_string(),
            ]),
        };

        let allowed = param.enum_values.as_ref().unwrap();
        assert_eq!(allowed.len(), 3);
        assert!(allowed.iter().any(|v| v == "fast"));
    }

    #[test]
    fn test_context_new() {
        let ctx = DriverContext::new();
        assert!(ctx.task_id().is_none());
        assert!(ctx.driver_index().is_none());
        assert!(ctx.driver_name().is_none());
        assert!(ctx.extra.is_empty());
    }

    #[test]
    fn test_context_with_task_id() {
        let ctx = DriverContext::with_task_id("task-123");
        assert_eq!(ctx.task_id(), Some("task-123"));
    }

    #[test]
    fn test_context_getters_setters() {
        let mut ctx = DriverContext::new();
        ctx.set_task_id("task-456")
            .set_driver_index(3)
            .set_driver_name("download_file");

        assert_eq!(ctx.task_id(), Some("task-456"));
        assert_eq!(ctx.driver_index(), Some(3));
        assert_eq!(ctx.driver_name(), Some("download_file"));
    }

    #[test]
    fn test_context_extra_operations() {
        let mut ctx = DriverContext::new();
        ctx.insert_extra("url", json!("https://example.com"))
            .insert_extra("timeout", json!(30));

        assert!(ctx.has_extra("url"));
        assert_eq!(ctx.get_extra("url"), Some(&json!("https://example.com")));
        assert_eq!(ctx.get_extra("timeout"), Some(&json!(30)));

        let removed = ctx.remove_extra("timeout");
        assert_eq!(removed, Some(json!(30)));
        assert!(!ctx.has_extra("timeout"));
    }

    #[test]
    fn test_context_builder() {
        let ctx = DriverContext::new()
            .with_task_id_builder("task-789")
            .with_driver_index(5)
            .with_driver_name("process_data")
            .with_extra("retry_count", json!(3))
            .build();

        assert_eq!(ctx.task_id(), Some("task-789"));
        assert_eq!(ctx.driver_index(), Some(5));
        assert_eq!(ctx.driver_name(), Some("process_data"));
        assert_eq!(ctx.get_extra("retry_count"), Some(&json!(3)));
    }
}