use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::EvolutionResult;
/// Tuning knobs passed to every [`LlmAdapter`] completion call.
///
/// Each field is optional; `None` leaves the choice to the adapter.
#[derive(Clone, Debug)]
pub struct CompletionOptions {
    /// Upper bound on generated tokens, if any.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, if any.
    pub temperature: Option<f32>,
    /// Explicit model identifier; `None` means "adapter's choice".
    pub model: Option<String>,
}

impl Default for CompletionOptions {
    /// Returns `max_tokens = Some(1000)`, `temperature = Some(0.7)` and no model override.
    fn default() -> Self {
        let max_tokens = Some(1000);
        let temperature = Some(0.7);
        Self {
            max_tokens,
            temperature,
            model: None,
        }
    }
}
/// Asynchronous interface to a large-language-model provider.
///
/// Implementors must be `Send + Sync`, which lets them be shared as
/// `Arc<dyn LlmAdapter>` (the form [`LlmAdapterFactory`] hands out).
#[async_trait]
pub trait LlmAdapter: Send + Sync {
    /// Name of the provider backing this adapter.
    fn provider_name(&self) -> String;

    /// Completes a single free-form `prompt` using the supplied `options`.
    async fn complete(&self, prompt: &str, options: CompletionOptions) -> EvolutionResult<String>;

    /// Completes a conversation supplied as a list of JSON message objects.
    ///
    /// NOTE(review): the message schema is implementation-defined; the mock
    /// adapter in this module reads each message's `"content"` string.
    async fn chat_complete(
        &self,
        messages: Vec<Value>,
        options: CompletionOptions,
    ) -> EvolutionResult<String>;

    /// Lists the model identifiers this provider exposes.
    async fn list_models(&self) -> EvolutionResult<Vec<String>>;
}
/// An [`LlmAdapter`] that answers with canned, deterministic strings instead
/// of calling a real model.
pub struct MockLlmAdapter {
    /// Name this mock reports as its provider.
    provider_name: String,
}

impl MockLlmAdapter {
    /// Builds a mock that reports `provider_name` as its provider.
    pub fn new(provider_name: &str) -> Self {
        let owned_name = String::from(provider_name);
        MockLlmAdapter {
            provider_name: owned_name,
        }
    }
}
#[async_trait]
impl LlmAdapter for MockLlmAdapter {
fn provider_name(&self) -> String {
self.provider_name.clone()
}
async fn complete(&self, prompt: &str, _options: CompletionOptions) -> EvolutionResult<String> {
if prompt.is_empty() {
return Err(crate::error::EvolutionError::InvalidInput(
"Prompt cannot be empty".to_string(),
));
}
if prompt.len() > 100_000 {
return Err(crate::error::EvolutionError::InvalidInput(
"Prompt too long (max 100,000 characters)".to_string(),
));
}
let suspicious_patterns = [
"ignore previous instructions",
"system:",
"assistant:",
"user:",
"###",
"---END---",
"<|im_start|>",
"<|im_end|>",
];
let prompt_lower = prompt.to_lowercase();
for pattern in &suspicious_patterns {
if prompt_lower.contains(pattern) {
log::warn!("Potential prompt injection detected: {}", pattern);
break;
}
}
if (prompt.contains("Rate the quality") && prompt.contains("0.0 to 1.0"))
|| (prompt.to_lowercase().contains("overall quality score")
&& prompt.contains("0.0 to 1.0"))
{
if prompt.contains("2+2") {
return Ok("0.95".to_string()); } else if prompt.contains("Analyze") {
return Ok("0.75".to_string()); } else {
if prompt.contains("Respond with only the numerical score") {
return Ok("0.85".to_string());
} else {
return Ok("Overall quality score: 0.85".to_string()); }
}
}
let keywords: Vec<&str> = prompt
.split_whitespace()
.filter(|word| word.len() > 3)
.take(5)
.collect();
Ok(format!(
"Analysis of {}: Based on the request about {}, here's a detailed response covering these aspects.",
keywords.join(", "),
prompt.chars().take(100).collect::<String>()
))
}
async fn chat_complete(
&self,
messages: Vec<Value>,
options: CompletionOptions,
) -> EvolutionResult<String> {
let prompt = messages
.iter()
.filter_map(|msg| msg.get("content").and_then(|c| c.as_str()))
.collect::<Vec<_>>()
.join("\n");
self.complete(&prompt, options).await
}
async fn list_models(&self) -> EvolutionResult<Vec<String>> {
Ok(vec![
"mock-gpt-4".to_string(),
"mock-claude-3".to_string(),
"mock-llama-2".to_string(),
])
}
}
/// Constructs [`LlmAdapter`] instances.
///
/// Every constructor currently returns a [`MockLlmAdapter`]. The model, config and preamble
/// arguments are accepted for interface compatibility but ignored.
pub struct LlmAdapterFactory;

impl LlmAdapterFactory {
    /// Maximum accepted provider-name length, measured in characters.
    const MAX_PROVIDER_NAME_CHARS: usize = 100;

    /// Wraps a [`MockLlmAdapter`] named `provider` in an `Arc`. No validation is performed.
    pub fn create_mock(provider: &str) -> Arc<dyn LlmAdapter> {
        Arc::new(MockLlmAdapter::new(provider))
    }

    /// Validates `provider` and returns a mock adapter for it.
    ///
    /// # Errors
    ///
    /// Returns `EvolutionError::InvalidInput` if `provider` is empty, longer than 100
    /// characters, or contains characters other than alphanumerics, `_` and `-`.
    pub fn from_config(
        provider: &str,
        _model: &str,
        _config: Option<Value>,
    ) -> EvolutionResult<Arc<dyn LlmAdapter>> {
        Self::validate_provider_name(provider)?;
        Ok(Self::create_mock(provider))
    }

    /// Validates `provider` and returns a mock adapter for it; `_preamble` is ignored.
    ///
    /// # Errors
    ///
    /// Same provider-name rules as [`LlmAdapterFactory::from_config`].
    pub fn create_specialized_agent(
        provider: &str,
        _model: &str,
        _preamble: &str,
    ) -> EvolutionResult<Arc<dyn LlmAdapter>> {
        // Apply the same validation as `from_config`; previously this path accepted any name.
        Self::validate_provider_name(provider)?;
        Ok(Self::create_mock(provider))
    }

    /// Checks that `provider` is a non-empty name of at most 100 characters drawn from
    /// alphanumerics, `_` and `-`.
    fn validate_provider_name(provider: &str) -> EvolutionResult<()> {
        if provider.is_empty() {
            return Err(crate::error::EvolutionError::InvalidInput(
                "Provider name cannot be empty".to_string(),
            ));
        }
        // `is_alphanumeric` admits multi-byte letters, so count chars rather than bytes
        // to honour the "characters" wording of the limit.
        if provider.chars().count() > Self::MAX_PROVIDER_NAME_CHARS {
            return Err(crate::error::EvolutionError::InvalidInput(
                "Provider name too long (max 100 characters)".to_string(),
            ));
        }
        if !provider
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
        {
            return Err(crate::error::EvolutionError::InvalidInput(
                "Provider name contains invalid characters".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options match the documented defaults.
    #[test]
    fn test_completion_options_default() {
        let opts = CompletionOptions::default();
        assert_eq!(opts.max_tokens, Some(1000));
        assert_eq!(opts.temperature, Some(0.7));
        assert!(opts.model.is_none());
    }

    /// `create_mock` keeps the provider name it was given.
    #[test]
    fn test_factory_create_mock() {
        let adapter = LlmAdapterFactory::create_mock("test");
        assert_eq!(adapter.provider_name(), "test");
    }

    /// `from_config` rejects empty, over-long and ill-formed provider names.
    #[test]
    fn test_factory_from_config_rejects_invalid_provider_names() {
        let too_long = "x".repeat(101);
        for provider in ["", "has space", "semi;colon", too_long.as_str()] {
            let result = LlmAdapterFactory::from_config(provider, "model", None);
            assert!(
                matches!(result, Err(crate::error::EvolutionError::InvalidInput(_))),
                "expected InvalidInput for provider {:?}",
                provider
            );
        }
    }

    /// `from_config` accepts names built from alphanumerics, `_` and `-`.
    #[test]
    fn test_factory_from_config_accepts_valid_provider_name() {
        let adapter = LlmAdapterFactory::from_config("open_ai-2", "model", None)
            .expect("valid provider name should be accepted");
        assert_eq!(adapter.provider_name(), "open_ai-2");
    }

    /// Ordinary prompts get the generic "Analysis of" response.
    #[tokio::test]
    async fn test_mock_adapter_complete() {
        let adapter = MockLlmAdapter::new("test");
        let result = adapter
            .complete("test prompt", CompletionOptions::default())
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("Analysis of"));
    }

    /// Empty prompts and prompts over the length limit are rejected.
    #[tokio::test]
    async fn test_mock_adapter_rejects_empty_and_oversized_prompts() {
        let adapter = MockLlmAdapter::new("test");

        let empty = adapter.complete("", CompletionOptions::default()).await;
        assert!(matches!(
            empty,
            Err(crate::error::EvolutionError::InvalidInput(_))
        ));

        let oversized_prompt = "a".repeat(100_001);
        let oversized = adapter
            .complete(&oversized_prompt, CompletionOptions::default())
            .await;
        assert!(matches!(
            oversized,
            Err(crate::error::EvolutionError::InvalidInput(_))
        ));
    }

    /// Quality-rating prompts return the canned scores.
    #[tokio::test]
    async fn test_mock_adapter_quality_score_prompts() {
        let adapter = MockLlmAdapter::new("test");

        let score = adapter
            .complete(
                "Rate the quality of 2+2 from 0.0 to 1.0",
                CompletionOptions::default(),
            )
            .await
            .unwrap();
        assert_eq!(score, "0.95");

        let fallback = adapter
            .complete(
                "Rate the quality of this essay from 0.0 to 1.0",
                CompletionOptions::default(),
            )
            .await
            .unwrap();
        assert_eq!(fallback, "Overall quality score: 0.85");
    }

    /// `chat_complete` joins message contents before completing; no content means an error.
    #[tokio::test]
    async fn test_mock_adapter_chat_complete_joins_message_contents() {
        let adapter = MockLlmAdapter::new("test");
        let messages = vec![
            serde_json::json!({"role": "system", "content": "Rate the quality"}),
            serde_json::json!({"role": "user", "content": "of 2+2 from 0.0 to 1.0"}),
        ];
        let reply = adapter
            .chat_complete(messages, CompletionOptions::default())
            .await
            .unwrap();
        assert_eq!(reply, "0.95");

        let no_content = adapter
            .chat_complete(
                vec![serde_json::json!({"role": "user"})],
                CompletionOptions::default(),
            )
            .await;
        assert!(no_content.is_err());
    }

    /// `list_models` returns the fixed mock model list.
    #[tokio::test]
    async fn test_mock_adapter_list_models() {
        let models = MockLlmAdapter::new("test").list_models().await.unwrap();
        assert_eq!(models, vec!["mock-gpt-4", "mock-claude-3", "mock-llama-2"]);
    }
}