//! Chat completion handler for the Meta Llama provider.
//!
//! [`LlamaChatHandler`] validates OpenAI-style chat requests and delegates
//! all request/response wire-format translation to [`LlamaChatTransformation`].

pub mod transformation;

use std::collections::HashMap;

use serde_json::Value;
use tracing::debug;

use crate::core::providers::meta_llama::LlamaConfig;
use crate::core::providers::unified_provider::ProviderError;
use crate::core::types::{chat::ChatRequest, responses::ChatResponse};

pub use transformation::LlamaChatTransformation;

/// Handler for chat completions against Meta Llama models.
///
/// A thin wrapper that performs request validation and defers the actual
/// format conversion to the transformation layer.
#[derive(Debug, Clone)]
pub struct LlamaChatHandler {
    transformation: LlamaChatTransformation,
}

impl LlamaChatHandler {
    /// Creates a new handler. The config is currently unused (note the
    /// `_config` binding) but accepted so construction can fail or consume
    /// provider settings without a signature change.
    pub fn new(_config: LlamaConfig) -> Result<Self, ProviderError> {
        Ok(Self {
            transformation: LlamaChatTransformation::new(),
        })
    }

    /// Converts an OpenAI-style [`ChatRequest`] into the provider's JSON body.
    pub fn transform_request(&self, request: ChatRequest) -> Result<Value, ProviderError> {
        debug!("Transforming chat request for model: {}", request.model);
        self.transformation.transform_request(request)
    }

    /// Converts a raw Llama JSON response back into a [`ChatResponse`].
    pub fn transform_response(&self, response: Value) -> Result<ChatResponse, ProviderError> {
        debug!("Transforming Llama response");
        self.transformation.transform_response(response)
    }

    /// Lists the OpenAI parameter names this provider supports.
    pub fn get_supported_openai_params(&self) -> Vec<String> {
        self.transformation.get_supported_params()
    }

    /// Validates a request before transformation.
    ///
    /// Rejects an empty model name, an empty message list, a `temperature`
    /// outside `0.0..=2.0`, and a `top_p` outside `0.0..=1.0`.
    pub fn validate_request(&self, request: &ChatRequest) -> Result<(), ProviderError> {
        if request.model.is_empty() {
            return Err(ProviderError::invalid_request("meta", "Model is required"));
        }
        if request.messages.is_empty() {
            return Err(ProviderError::invalid_request(
                "meta",
                "Messages cannot be empty",
            ));
        }
        // Both range checks use inclusive bounds, matching the OpenAI
        // parameter ranges.
        if let Some(temp) = request.temperature
            && !(0.0..=2.0).contains(&temp)
        {
            return Err(ProviderError::invalid_request(
                "meta",
                format!("Temperature must be between 0 and 2, got {}", temp),
            ));
        }
        if let Some(top_p) = request.top_p
            && !(0.0..=1.0).contains(&top_p)
        {
            return Err(ProviderError::invalid_request(
                "meta",
                format!("top_p must be between 0 and 1, got {}", top_p),
            ));
        }
        Ok(())
    }

    /// Maps OpenAI-style parameter names and values to their Llama
    /// equivalents for the given model, delegating to the transformation layer.
    pub fn map_openai_params(
        &self,
        params: HashMap<String, Value>,
        model: &str,
    ) -> HashMap<String, Value> {
        self.transformation.map_openai_params(params, model)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::{
        chat::ChatMessage,
        message::{MessageContent, MessageRole},
    };

    #[test]
    fn test_handler_creation() {
        let config = LlamaConfig::default();
        let handler = LlamaChatHandler::new(config);
        assert!(handler.is_ok());
    }

    #[test]
    fn test_supported_params() {
        let config = LlamaConfig {
            api_key: "test".to_string(),
            ..Default::default()
        };
        let handler = LlamaChatHandler::new(config).unwrap();
        let params = handler.get_supported_openai_params();
        assert!(params.contains(&"messages".to_string()));
        assert!(params.contains(&"model".to_string()));
        assert!(params.contains(&"temperature".to_string()));
        assert!(params.contains(&"stream".to_string()));
    }

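    // The bounds checks in `validate_request` use inclusive ranges
    // (`0.0..=2.0` for temperature, `0.0..=1.0` for top_p), so the endpoint
    // values themselves should be accepted. A small sketch of that invariant:
    #[test]
    fn test_validation_accepts_inclusive_bounds() {
        let config = LlamaConfig {
            api_key: "test".to_string(),
            ..Default::default()
        };
        let handler = LlamaChatHandler::new(config).unwrap();
        let mut request = ChatRequest {
            model: "llama3.1-8b".to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: Some(MessageContent::Text("Hello".to_string())),
                ..Default::default()
            }],
            temperature: Some(2.0),
            top_p: Some(1.0),
            ..Default::default()
        };
        assert!(handler.validate_request(&request).is_ok());

        // The lower endpoints are equally valid.
        request.temperature = Some(0.0);
        request.top_p = Some(0.0);
        assert!(handler.validate_request(&request).is_ok());
    }
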
    #[test]
    fn test_request_validation() {
        let config = LlamaConfig {
            api_key: "test".to_string(),
            ..Default::default()
        };
        let handler = LlamaChatHandler::new(config).unwrap();

        let valid_request = ChatRequest {
            model: "llama3.1-8b".to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: Some(MessageContent::Text("Hello".to_string())),
                ..Default::default()
            }],
            temperature: Some(0.8),
            top_p: Some(0.9),
            ..Default::default()
        };
        assert!(handler.validate_request(&valid_request).is_ok());

        // Out-of-range temperature is rejected.
        let mut invalid_request = valid_request.clone();
        invalid_request.temperature = Some(3.0);
        assert!(handler.validate_request(&invalid_request).is_err());

        // Out-of-range top_p is rejected.
        let mut invalid_request = valid_request.clone();
        invalid_request.top_p = Some(1.5);
        assert!(handler.validate_request(&invalid_request).is_err());

        // An empty message list is rejected.
        let mut invalid_request = valid_request.clone();
        invalid_request.messages.clear();
        assert!(handler.validate_request(&invalid_request).is_err());
    }

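    // The empty-model guard in `validate_request` is not covered above; this
    // minimal case exercises it.
    #[test]
    fn test_empty_model_rejected() {
        let config = LlamaConfig {
            api_key: "test".to_string(),
            ..Default::default()
        };
        let handler = LlamaChatHandler::new(config).unwrap();
        let request = ChatRequest {
            model: String::new(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: Some(MessageContent::Text("Hello".to_string())),
                ..Default::default()
            }],
            ..Default::default()
        };
        assert!(handler.validate_request(&request).is_err());
    }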
}