use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
/// Errors that a chain can report.
///
/// Every variant carries a `String` payload. The `Display` output is a Chinese
/// prefix followed by that payload.
#[derive(Debug)]
pub enum ChainError {
    /// A required input key was not supplied; the payload is the key name.
    /// Rendered as `缺少输入: <key>` ("missing input").
    MissingInput(String),
    /// Failure related to a chain's output; the payload is a message.
    /// Rendered as `输出错误: <msg>` ("output error").
    OutputError(String),
    /// Failure while running a chain; the payload is a message.
    /// Rendered as `执行错误: <msg>` ("execution error").
    ExecutionError(String),
    /// Any other chain failure; the payload is a message.
    /// Rendered as `Chain 错误: <msg>` ("chain error").
    Other(String),
}

impl std::fmt::Display for ChainError {
    /// Writes the variant's localized prefix, a colon and a space, then the payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingInput(missing_key) => write!(f, "缺少输入: {}", missing_key),
            Self::OutputError(detail) => write!(f, "输出错误: {}", detail),
            Self::ExecutionError(detail) => write!(f, "执行错误: {}", detail),
            Self::Other(detail) => write!(f, "Chain 错误: {}", detail),
        }
    }
}

/// Lets `ChainError` be used as a standard error (e.g. boxed as `dyn Error`).
/// There is no underlying source error, so the default methods are kept.
impl std::error::Error for ChainError {}
pub type ChainResult = HashMap<String, Value>;
/// A chain step that consumes a map of named JSON values and asynchronously
/// produces a [`ChainResult`].
///
/// Implementors declare the input keys they require and the output keys they
/// produce. Because of the `Send + Sync` bound, implementations can be shared
/// across threads.
#[async_trait]
pub trait BaseChain: Send + Sync {
    /// Names of the keys this chain requires in its input map.
    fn input_keys(&self) -> Vec<&str>;

    /// Names of the keys this chain declares for its output map.
    fn output_keys(&self) -> Vec<&str>;

    /// Runs the chain on `inputs` and returns its outputs.
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as a [`ChainError`].
    async fn invoke(&self, inputs: HashMap<String, Value>) -> Result<ChainResult, ChainError>;

    /// Checks that every key returned by [`BaseChain::input_keys`] is present in `inputs`.
    ///
    /// Keys are checked in the order `input_keys` yields them. Extra entries in
    /// `inputs` are ignored.
    ///
    /// # Errors
    /// Returns [`ChainError::MissingInput`] naming the first absent key.
    fn validate_inputs(&self, inputs: &HashMap<String, Value>) -> Result<(), ChainError> {
        let first_missing = self
            .input_keys()
            .into_iter()
            .find(|required| !inputs.contains_key(*required));

        match first_missing {
            Some(absent) => Err(ChainError::MissingInput(absent.to_string())),
            None => Ok(()),
        }
    }

    /// Identifier for this chain; the default implementation returns `"chain"`.
    fn name(&self) -> &str {
        "chain"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal chain used to exercise the trait's provided methods.
    struct StubChain;

    #[async_trait]
    impl BaseChain for StubChain {
        fn input_keys(&self) -> Vec<&str> {
            vec!["question", "context"]
        }

        fn output_keys(&self) -> Vec<&str> {
            vec!["answer"]
        }

        async fn invoke(
            &self,
            inputs: HashMap<String, Value>,
        ) -> Result<ChainResult, ChainError> {
            self.validate_inputs(&inputs)?;
            Ok(ChainResult::new())
        }
    }

    /// Every variant renders its exact prefix followed by the payload.
    #[test]
    fn test_chain_error_display() {
        let cases = [
            (ChainError::MissingInput("test".to_string()), "缺少输入: test"),
            (ChainError::OutputError("test".to_string()), "输出错误: test"),
            (ChainError::ExecutionError("test".to_string()), "执行错误: test"),
            (ChainError::Other("test".to_string()), "Chain 错误: test"),
        ];
        for (error, expected) in cases.iter() {
            assert_eq!(error.to_string(), *expected);
        }
    }

    /// `ChainError` is usable as a boxed `std::error::Error`.
    #[test]
    fn test_chain_error_is_std_error() {
        let boxed: Box<dyn std::error::Error> = Box::new(ChainError::Other("boom".to_string()));
        assert_eq!(boxed.to_string(), "Chain 错误: boom");
    }

    /// `validate_inputs` reports the first missing key in `input_keys` order,
    /// and accepts maps that contain every required key (extras are ignored).
    #[test]
    fn test_validate_inputs_reports_first_missing_key() {
        let chain = StubChain;
        let mut inputs = HashMap::new();

        match chain.validate_inputs(&inputs) {
            Err(ChainError::MissingInput(key)) => assert_eq!(key, "question"),
            other => panic!("expected MissingInput(\"question\"), got {:?}", other),
        }

        inputs.insert("question".to_string(), Value::Null);
        match chain.validate_inputs(&inputs) {
            Err(ChainError::MissingInput(key)) => assert_eq!(key, "context"),
            other => panic!("expected MissingInput(\"context\"), got {:?}", other),
        }

        inputs.insert("context".to_string(), Value::Null);
        inputs.insert("unrelated".to_string(), Value::Null);
        assert!(chain.validate_inputs(&inputs).is_ok());
    }

    /// The provided `name` defaults to "chain".
    #[test]
    fn test_default_name() {
        assert_eq!(StubChain.name(), "chain");
    }
}