use crate::common::TestResult;
use crate::error::{Error, Result};
use log::info;
use serde::{Deserialize, Serialize};
use std::fs;
use std::process::Command;
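
/// Supported local model backends.
///
/// With serde's default externally-tagged representation these unit variants
/// (de)serialize as plain strings. A doc-test sketch (marked `ignore` because
/// the crate path of this module is not known here):
///
/// ```rust,ignore
/// let model: AiModelType = serde_json::from_str("\"Llama\"").unwrap();
/// ```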
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AiModelType {
    Llama,
    Mistral,
    GptJ,
    Phi,
    /// A user-supplied model; requires `model_path` to be set.
    Custom,
}

/// Configuration for the AI test generator. Fields marked `#[serde(default)]`
/// may be omitted from the JSON input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiConfig {
    pub model_type: AiModelType,
    /// Filesystem path to the model weights. Required for `Custom`; the Llama
    /// backend falls back to a built-in default path when unset.
    pub model_path: Option<String>,
    #[serde(default = "default_context_size")]
    pub context_size: usize,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Optional system prompt (not yet consumed by the prompt builders in
    /// this module).
    pub system_prompt: Option<String>,
}

fn default_context_size() -> usize {
    2048
}

fn default_temperature() -> f32 {
    0.7
}

fn default_max_tokens() -> usize {
    1024
}
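
/// Drives prompt construction, local model inference, and validation of the
/// generated output.
///
/// A minimal usage sketch (hedged: marked `ignore` because the crate path is
/// unknown here, and a working `llama-cli` install plus model file are
/// assumed):
///
/// ```rust,ignore
/// let generator = AiTestGenerator::new(AiConfig {
///     model_type: AiModelType::Llama,
///     model_path: None, // Llama falls back to the built-in default path
///     context_size: 2048,
///     temperature: 0.7,
///     max_tokens: 1024,
///     system_prompt: None,
/// });
/// let json = generator
///     .generate_test_config("Check that GET /health returns 200", "api")
///     .await?;
/// ```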
pub struct AiTestGenerator {
    config: AiConfig,
}

impl AiTestGenerator {
    pub fn new(config: AiConfig) -> Self {
        Self { config }
    }

    /// Generates a JSON test configuration from a natural-language
    /// description, then validates it for the requested test type.
    pub async fn generate_test_config(&self, description: &str, test_type: &str) -> Result<String> {
        info!("Generating {} test configuration from description", test_type);
        let prompt = match test_type {
            "api" => self.create_api_test_prompt(description),
            "performance" => self.create_performance_test_prompt(description),
            "security" => self.create_security_test_prompt(description),
            "web" => self.create_web_test_prompt(description),
            _ => return Err(Error::ValidationError(format!("Unsupported test type: {}", test_type))),
        };
        let json_config = self.run_inference(&prompt).await?;
        match test_type {
            "api" => self.validate_api_test_config(&json_config)?,
            "performance" => self.validate_performance_test_config(&json_config)?,
            "security" => self.validate_security_test_config(&json_config)?,
            "web" => self.validate_web_test_config(&json_config)?,
            // Unsupported types were already rejected when building the prompt.
            _ => unreachable!("test type was validated before inference"),
        }
        Ok(json_config)
    }

    /// Produces a markdown analysis of the given test results.
    pub async fn analyze_test_results(&self, results: &[TestResult]) -> Result<String> {
        info!("Analyzing test results using AI");
        let results_json = serde_json::to_string_pretty(results)?;
        let prompt = self.create_analysis_prompt(&results_json);
        self.run_inference(&prompt).await
    }

    /// Produces markdown suggestions for improving the tests or the system
    /// under test.
    pub async fn suggest_improvements(&self, results: &[TestResult]) -> Result<String> {
        info!("Suggesting improvements based on test results");
        let results_json = serde_json::to_string_pretty(results)?;
        let prompt = self.create_improvement_prompt(&results_json);
        self.run_inference(&prompt).await
    }

    /// Dispatches the prompt to the configured backend.
    async fn run_inference(&self, prompt: &str) -> Result<String> {
        match self.config.model_type {
            AiModelType::Llama => self.run_llama_inference(prompt).await,
            AiModelType::Mistral => self.run_mistral_inference(prompt).await,
            AiModelType::GptJ => self.run_gptj_inference(prompt).await,
            AiModelType::Phi => self.run_phi_inference(prompt).await,
            AiModelType::Custom => {
                if let Some(path) = &self.config.model_path {
                    self.run_custom_inference(prompt, path).await
                } else {
                    Err(Error::ConfigError("Model path is required for custom models".to_string()))
                }
            }
        }
    }

    /// Runs inference by shelling out to a local `llama-cli` binary.
    async fn run_llama_inference(&self, prompt: &str) -> Result<String> {
        let model_path = self.config.model_path.clone().unwrap_or_else(|| {
            "/usr/local/share/models/llama-2-7b-chat.gguf".to_string()
        });
        // Pass the prompt via a temp file to sidestep shell quoting and
        // argument-length limits. The file is deleted when `temp_file` drops.
        let temp_file = tempfile::NamedTempFile::new()?;
        fs::write(temp_file.path(), prompt)?;
        // Note: `std::process::Command::output` blocks the current thread
        // until `llama-cli` exits, which can stall the async executor during
        // long inferences (see the non-blocking sketch below).
        let output = Command::new("llama-cli")
            .arg("--model")
            .arg(model_path)
            .arg("--ctx-size")
            .arg(self.config.context_size.to_string())
            .arg("--temp")
            .arg(self.config.temperature.to_string())
            .arg("--n-predict")
            .arg(self.config.max_tokens.to_string())
            .arg("--file")
            .arg(temp_file.path())
            .output()?;
        if output.status.success() {
            let result = String::from_utf8_lossy(&output.stdout).to_string();
            Ok(self.extract_json_from_output(&result))
        } else {
            let error = String::from_utf8_lossy(&output.stderr).to_string();
            Err(Error::TestError(format!("Llama inference failed: {}", error)))
        }
    }
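
    // A non-blocking variant, left as a comment because it assumes the crate
    // runs on the tokio runtime with the "process" feature enabled, which this
    // module's imports do not confirm. `tokio::process::Command` mirrors the
    // std builder API but yields to the executor while the child runs:
    //
    //     let output = tokio::process::Command::new("llama-cli")
    //         .arg("--model")
    //         .arg(&model_path)
    //         .arg("--file")
    //         .arg(temp_file.path())
    //         .output()
    //         .await?;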

    // The remaining backends are stubs. Returning an error (rather than a
    // placeholder string) keeps "not yet implemented" text from flowing into
    // JSON validation as if it were model output.
    async fn run_mistral_inference(&self, _prompt: &str) -> Result<String> {
        Err(Error::ConfigError("Mistral inference is not yet implemented".to_string()))
    }

    async fn run_gptj_inference(&self, _prompt: &str) -> Result<String> {
        Err(Error::ConfigError("GPT-J inference is not yet implemented".to_string()))
    }

    async fn run_phi_inference(&self, _prompt: &str) -> Result<String> {
        Err(Error::ConfigError("Phi inference is not yet implemented".to_string()))
    }

    async fn run_custom_inference(&self, _prompt: &str, model_path: &str) -> Result<String> {
        Err(Error::ConfigError(format!(
            "Custom inference with model {} is not yet implemented",
            model_path
        )))
    }
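
    /// Extracts the outermost JSON object from raw model output, which often
    /// wraps the JSON in prose. Falls back to returning the output unchanged
    /// when no object is found. A doc-test sketch (`ignore`d because the
    /// method is private):
    ///
    /// ```rust,ignore
    /// let out = generator.extract_json_from_output("Sure! {\"url\": \"/health\"} Done.");
    /// assert_eq!(out, "{\"url\": \"/health\"}");
    /// ```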
    fn extract_json_from_output(&self, output: &str) -> String {
        if let (Some(start), Some(end)) = (output.find('{'), output.rfind('}')) {
            // Guard against a stray '}' appearing before the first '{',
            // which would otherwise panic when slicing.
            if end >= start {
                return output[start..=end].to_string();
            }
        }
        output.to_string()
    }

    fn create_api_test_prompt(&self, description: &str) -> String {
        format!(
            "Generate a JSON configuration for an API test based on this description: {}\n\
             The configuration should include appropriate values for URL, method, headers, \
             expected status, and expected response body. Format as valid JSON for a QitOps API test.",
            description
        )
    }

    fn create_performance_test_prompt(&self, description: &str) -> String {
        format!(
            "Generate a JSON configuration for a performance test based on this description: {}\n\
             The configuration should include appropriate values for target URL, method, headers, \
             success threshold, and ramp-up time. Format as valid JSON for a QitOps performance test.",
            description
        )
    }

    fn create_security_test_prompt(&self, description: &str) -> String {
        format!(
            "Generate a JSON configuration for a security test based on this description: {}\n\
             The configuration should include appropriate values for target URL, headers, auth, \
             scan types, and maximum high severity findings. Format as valid JSON for a QitOps security test.",
            description
        )
    }

    fn create_web_test_prompt(&self, description: &str) -> String {
        format!(
            "Generate a JSON configuration for a web test based on this description: {}\n\
             The configuration should include appropriate values for target URL, viewport, \
             wait conditions, assertions, and actions. Format as valid JSON for a QitOps web test.",
            description
        )
    }

    fn create_analysis_prompt(&self, results_json: &str) -> String {
        format!(
            "Analyze these test results and provide insights:\n{}\n\
             Focus on patterns, anomalies, and potential issues. \
             Format your analysis in markdown with sections for summary, details, and recommendations.",
            results_json
        )
    }

    fn create_improvement_prompt(&self, results_json: &str) -> String {
        format!(
            "Based on these test results, suggest improvements to the tests or the system under test:\n{}\n\
             Focus on concrete, actionable suggestions that would improve test coverage, reliability, or performance. \
             Format your suggestions in markdown with bullet points.",
            results_json
        )
    }
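
    // Validation parses the generated string as JSON and checks required
    // top-level fields. The field names follow the prompts above; the exact
    // QitOps schemas are assumed rather than confirmed here. For example:
    //
    //     {"url": "https://example.com/health", "method": "GET"}  // passes API validation
    //     {"method": "GET"}                                       // fails: missing "url"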
    fn validate_required_fields(&self, config: &str, test_kind: &str, required: &[&str]) -> Result<()> {
        // Parsing first catches output that merely mentions a field name in
        // surrounding prose, which a plain substring check would accept.
        let value: serde_json::Value = serde_json::from_str(config)?;
        for field in required {
            if value.get(*field).is_none() {
                return Err(Error::ValidationError(format!(
                    "Generated {} test config is missing required field '{}'",
                    test_kind, field
                )));
            }
        }
        Ok(())
    }

    fn validate_api_test_config(&self, config: &str) -> Result<()> {
        self.validate_required_fields(config, "API", &["url", "method"])
    }

    fn validate_performance_test_config(&self, config: &str) -> Result<()> {
        self.validate_required_fields(config, "performance", &["target_url", "method"])
    }

    fn validate_security_test_config(&self, config: &str) -> Result<()> {
        self.validate_required_fields(config, "security", &["target_url"])
    }

    fn validate_web_test_config(&self, config: &str) -> Result<()> {
        self.validate_required_fields(config, "web", &["target_url"])
    }
}
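
// A small test module exercising the pure, offline parts of this file. It
// assumes only `serde_json`, which is already a dependency of this module.
#[cfg(test)]
mod tests {
    use super::*;

    fn generator() -> AiTestGenerator {
        AiTestGenerator::new(AiConfig {
            model_type: AiModelType::Llama,
            model_path: None,
            context_size: default_context_size(),
            temperature: default_temperature(),
            max_tokens: default_max_tokens(),
            system_prompt: None,
        })
    }

    #[test]
    fn serde_defaults_apply_when_fields_are_omitted() {
        let config: AiConfig = serde_json::from_str(r#"{"model_type": "Llama"}"#).unwrap();
        assert_eq!(config.context_size, 2048);
        assert_eq!(config.max_tokens, 1024);
    }

    #[test]
    fn extract_json_pulls_the_outermost_object() {
        let out = generator().extract_json_from_output("noise {\"url\": \"/health\"} trailing");
        assert_eq!(out, "{\"url\": \"/health\"}");
    }

    #[test]
    fn api_validation_requires_url_and_method() {
        let gen = generator();
        assert!(gen
            .validate_api_test_config(r#"{"url": "https://example.com", "method": "GET"}"#)
            .is_ok());
        assert!(gen.validate_api_test_config(r#"{"method": "GET"}"#).is_err());
    }
}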