use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use tracing::{info, debug};
use crate::core::CrateContext;
use crate::providers::{OpenAIProvider, GenerationRequest, LegacyLLMProvider};
use super::Stage;
pub struct TestingStage {
openai_provider: OpenAIProvider,
}
impl TestingStage {
pub fn new(openai_provider: OpenAIProvider) -> Self {
Self { openai_provider }
}
async fn generate_unit_tests(&self, file_path: &str, content: &str) -> Result<String> {
let prompt = format!(
"Generate comprehensive unit tests for this Rust code:\n\n{}\n\n\
Include tests for:\n\
- Happy path scenarios\n\
- Edge cases\n\
- Error conditions\n\
- Property-based tests where applicable",
content
);
let request = GenerationRequest {
prompt,
context: format!("Testing file: {}", file_path),
max_tokens: Some(2048),
temperature: Some(0.4),
model: None,
};
let response = self.openai_provider.generate(request).await?;
Ok(response.content)
}
async fn generate_integration_tests(&self, context: &CrateContext) -> Result<String> {
let prompt = format!(
"Generate integration tests for the Rust crate '{}' with description: '{}'. \
Test the public API and main use cases.",
context.spec.name,
context.spec.description
);
let request = GenerationRequest {
prompt,
context: serde_json::to_string(context)?,
max_tokens: Some(2048),
temperature: Some(0.4),
model: None,
};
let response = self.openai_provider.generate(request).await?;
Ok(response.content)
}
async fn generate_benchmark_tests(&self, context: &CrateContext) -> Result<String> {
let prompt = format!(
"Generate benchmark tests for the Rust crate '{}'. \
Focus on performance-critical functions and measure key metrics.",
context.spec.name
);
let request = GenerationRequest {
prompt,
context: serde_json::to_string(context)?,
max_tokens: Some(1536),
temperature: Some(0.4),
model: None,
};
let response = self.openai_provider.generate(request).await?;
Ok(response.content)
}
async fn add_test_files(&self, context: &CrateContext) -> Result<HashMap<String, String>> {
let mut files_with_tests = context.file_structure.clone();
for (path, content) in &context.file_structure {
if path.ends_with(".rs") && path.starts_with("src/") && !path.contains("test") {
let test_content = self.generate_unit_tests(path, content).await?;
let enhanced_content = format!("{}\n\n#[cfg(test)]\nmod tests {{\n use super::*;\n\n{}\n}}",
content, test_content);
files_with_tests.insert(path.clone(), enhanced_content);
}
}
let integration_tests = self.generate_integration_tests(context).await?;
files_with_tests.insert("tests/integration_test.rs".to_string(), integration_tests);
let benchmark_tests = self.generate_benchmark_tests(context).await?;
files_with_tests.insert("benches/benchmark.rs".to_string(), benchmark_tests);
let test_config = self.generate_test_config(context).await?;
files_with_tests.insert(".github/workflows/test.yml".to_string(), test_config);
Ok(files_with_tests)
}
async fn generate_test_config(&self, context: &CrateContext) -> Result<String> {
let config = format!(
r#"name: Tests
// TODO: document this
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
// TODO: document this
env:
CARGO_TERM_COLOR: always
// TODO: document this
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true
components: rustfmt, clippy
- name: Cache cargo registry
uses: actions/cache@v3
with:
path: ~/.cargo/registry
key: ${{{{ runner.os }}}}-cargo-registry-${{{{ hashFiles('**/Cargo.lock') }}}}
- name: Cache cargo index
uses: actions/cache@v3
with:
path: ~/.cargo/git
key: ${{{{ runner.os }}}}-cargo-index-${{{{ hashFiles('**/Cargo.lock') }}}}
- name: Cache cargo build
uses: actions/cache@v3
with:
path: target
key: ${{{{ runner.os }}}}-cargo-build-target-${{{{ hashFiles('**/Cargo.lock') }}}}
- name: Check formatting
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy -- -D warnings
- name: Run tests
run: cargo test --verbose
- name: Run integration tests
run: cargo test --test integration_test
- name: Run benchmarks
run: cargo bench
"#,
);
Ok(config)
}
}
#[async_trait]
impl Stage for TestingStage {
    type Input = CrateContext;
    type Output = CrateContext;

    /// Produces a copy of `input` whose file map includes the generated tests.
    ///
    /// The copy also gets two metadata entries. `tests_generated_at` is an RFC 3339
    /// UTC timestamp taken after generation finishes. `test_coverage` is set to
    /// `"comprehensive"`.
    ///
    /// # Errors
    /// Fails if any generation request or context serialization fails. `input` itself
    /// is only borrowed and never modified.
    async fn execute(&self, input: &Self::Input) -> Result<Self::Output> {
        let crate_name = &input.spec.name;
        info!("Starting test generation for crate: {}", crate_name);

        let generated_files = self.add_test_files(input).await?;

        let mut output = input.clone();
        output.file_structure = generated_files;

        let generated_at = chrono::Utc::now().to_rfc3339();
        for (key, value) in [
            ("tests_generated_at", generated_at),
            ("test_coverage", String::from("comprehensive")),
        ] {
            output.metadata.insert(key.to_string(), value);
        }

        debug!("Test generation complete for crate: {}", crate_name);
        Ok(output)
    }

    /// Stage identifier used by the pipeline.
    fn name(&self) -> &str {
        "testing"
    }
}