use async_trait::async_trait;
use crate::error::AiError;
use crate::request::{
GenerateRequest, GenerateResponse, OptimizeRequest, OptimizeResponse, RepairRequest,
RepairResponse, SelectOption, SelectRequest, SelectResponse,
};
/// A backend that can perform the four AI operations used by this crate:
/// generation, repair, optimization, and option selection.
///
/// The trait requires `Send + Sync`, so implementations can be shared between
/// threads. `#[async_trait]` is applied so the operations can be declared as
/// `async fn`. Every fallible operation reports failure as an [`AiError`].
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Handles a [`GenerateRequest`] and returns the provider's [`GenerateResponse`].
    async fn generate(&self, request: &GenerateRequest) -> Result<GenerateResponse, AiError>;

    /// Handles a [`RepairRequest`] and returns the provider's [`RepairResponse`].
    async fn repair(&self, request: &RepairRequest) -> Result<RepairResponse, AiError>;

    /// Handles an [`OptimizeRequest`] and returns the provider's [`OptimizeResponse`].
    async fn optimize(&self, request: &OptimizeRequest) -> Result<OptimizeResponse, AiError>;

    /// Chooses an option for a [`SelectRequest`].
    ///
    /// Nothing in this signature guarantees that the returned `selected_id` is one
    /// of the offered options; [`validate_select_response`] performs that check.
    async fn select(&self, request: &SelectRequest) -> Result<SelectResponse, AiError>;

    /// Identifier of the model behind this provider, returned as an owned `String`.
    ///
    /// NOTE(review): the exact format (e.g. a vendor model name) is defined by each
    /// implementation — confirm against implementors.
    fn model_id(&self) -> String;
}
/// Checks that `response.selected_id` matches the `id` of one of the offered `options`.
///
/// A provider's selection is only usable if it names an option that was actually
/// offered; run this before acting on `response.selected_id`.
///
/// # Errors
///
/// Returns [`AiError::InvalidResponse`] when no entry in `options` has an `id` equal to
/// `response.selected_id`. An empty `options` slice therefore always produces an error.
/// The message starts with `selected_id '<id>' not in options` and is followed by the
/// list of ids that would have been accepted, so a bad answer can be diagnosed from
/// the error alone.
pub fn validate_select_response(
    options: &[SelectOption],
    response: &SelectResponse,
) -> Result<(), AiError> {
    if options.iter().any(|o| o.id == response.selected_id) {
        return Ok(());
    }
    // Only built on the failure path, so the success path stays allocation-free.
    let valid_ids: Vec<_> = options.iter().map(|o| &o.id).collect();
    Err(AiError::InvalidResponse(format!(
        "selected_id '{}' not in options {:?}",
        response.selected_id, valid_ids
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-option menu (`retry`, `fallback`) shared by the tests; each option's
    /// description is simply its id.
    fn opts() -> Vec<SelectOption> {
        ["retry", "fallback"]
            .iter()
            .map(|&name| SelectOption {
                id: name.into(),
                description: name.into(),
            })
            .collect()
    }

    #[test]
    fn validate_accepts_id_in_set() {
        let resp = SelectResponse {
            selected_id: "retry".into(),
            confidence: 0.9,
            reasoning: None,
        };
        validate_select_response(&opts(), &resp).expect("accepted");
    }

    #[test]
    fn validate_rejects_id_not_in_set() {
        let resp = SelectResponse {
            selected_id: "escalate".into(),
            confidence: 0.9,
            reasoning: None,
        };
        let err = validate_select_response(&opts(), &resp).unwrap_err();
        // The rejected id must appear in the message so failures are diagnosable.
        if let AiError::InvalidResponse(msg) = &err {
            assert!(msg.contains("escalate"), "message missing id: {msg}");
        } else {
            panic!("expected InvalidResponse, got {err:?}");
        }
    }

    #[test]
    fn validate_rejects_empty_option_set() {
        let resp = SelectResponse {
            selected_id: "anything".into(),
            confidence: 1.0,
            reasoning: None,
        };
        // With nothing offered, no selection can be valid.
        let outcome = validate_select_response(&[], &resp);
        assert!(outcome.is_err());
    }
}