adversaria 0.1.0

Adversarial Testing Harness for Large Language Models
Documentation
use crate::core::{
    config::ProviderConfig,
    error::{AdversariaError, Result},
    ModelResponse,
};
use crate::providers::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Provider that talks to an Ollama server over HTTP.
///
/// Instances are created with `OllamaProvider::new`, which fills every field
/// from a `ProviderConfig`.
pub struct OllamaProvider {
    /// HTTP client used for every request; `new` builds it with `timeout` applied.
    client: Client,
    /// Base URL that endpoint paths such as `/api/generate` are appended to.
    api_base: String,
    /// Model name sent with each generation request and returned by `model()`.
    model: String,
    /// Request timeout that `new` passed to the client builder; kept here as well.
    timeout: Duration,
}

/// JSON body that `generate` posts to the `/api/generate` endpoint.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    /// Name of the model that should handle the prompt.
    model: String,
    /// Prompt text, forwarded exactly as the caller supplied it.
    prompt: String,
    /// Streaming flag; `generate` always sends `false`.
    /// NOTE(review): presumably this makes the server answer with a single JSON
    /// object instead of a stream of chunks — confirm against the Ollama API docs.
    stream: bool,
}

/// The parts of the `/api/generate` JSON reply that `generate` reads.
///
/// Only the two keys below are captured by this struct.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    /// Generated text; becomes `ModelResponse::content`.
    response: String,
    /// Model name as reported in the reply; becomes `ModelResponse::model`.
    model: String,
}

impl OllamaProvider {
    /// Builds a provider from `config`.
    ///
    /// * `config.api_base` — base URL of the Ollama server. When absent,
    ///   `http://localhost:11434` is used. Trailing `/` characters on a
    ///   configured value are removed, because request URLs are formed as
    ///   `{api_base}/api/...` and would otherwise contain a double slash.
    /// * `config.timeout_seconds` — request timeout in seconds; 60 when absent.
    /// * `config.model` — model name sent with every generation request.
    ///
    /// # Errors
    ///
    /// Returns `AdversariaError::Provider` when the HTTP client cannot be built.
    pub fn new(config: ProviderConfig) -> Result<Self> {
        let api_base = config
            .api_base
            // Normalise "http://host:11434/" to "http://host:11434" so that
            // appending "/api/generate" never produces "//api/generate".
            .map(|base| base.trim_end_matches('/').to_string())
            .unwrap_or_else(|| "http://localhost:11434".to_string());

        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(60));

        let client = Client::builder().timeout(timeout).build().map_err(|e| {
            AdversariaError::Provider(format!("Failed to create HTTP client: {}", e))
        })?;

        Ok(Self {
            client,
            api_base,
            model: config.model,
            timeout,
        })
    }
}

#[async_trait]
impl Provider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn generate(&self, prompt: &str) -> Result<ModelResponse> {
        let url = format!("{}/api/generate", self.api_base);

        let request = OllamaRequest {
            model: self.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
        };

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(AdversariaError::Provider(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }

        let ollama_response: OllamaResponse = response.json().await?;

        Ok(ModelResponse {
            content: ollama_response.response,
            model: ollama_response.model,
            usage: None,
        })
    }

    async fn health_check(&self) -> Result<bool> {
        let url = format!("{}/api/tags", self.api_base);

        let response = self.client.get(&url).send().await?;

        Ok(response.status().is_success())
    }
}