use crate::{
provider::{LlmError, Provider},
types::{ChatCompletionRequest, ChatCompletionResponse, Message as LlmMessage},
};
/// Rig-style ergonomic wrapper pairing an LLM provider with the default
/// model name used for completion requests built through it.
///
/// The `P: Provider` bound is deliberately left off the struct definition
/// and applied only on the impl blocks that actually need it (Rust API
/// guideline C-STRUCT-BOUNDS); this keeps the type usable in contexts
/// that never call provider methods.
pub struct RigAdapter<P> {
    // Backing provider; cloned into each builder by `completion()`.
    provider: P,
    // Default model identifier sent with every request.
    model: String,
}
impl<P: Provider> RigAdapter<P> {
    /// Creates an adapter from a provider and a default model name.
    ///
    /// Requires only `P: Provider`: cloning is needed solely by
    /// [`RigAdapter::completion`], so the `Clone` bound lives on that
    /// impl block instead of constraining construction.
    pub fn new(provider: P, model: impl Into<String>) -> Self {
        Self {
            provider,
            model: model.into(),
        }
    }
}

impl<P: Provider + Clone> RigAdapter<P> {
    /// Starts a completion request builder targeting this adapter's model.
    ///
    /// The provider is cloned so the builder owns its own handle and the
    /// adapter can keep issuing further requests.
    pub fn completion(&self) -> RigCompletionBuilder<P> {
        RigCompletionBuilder::new(self.provider.clone(), self.model.clone())
    }
}
pub struct RigCompletionBuilder<P: Provider> {
provider: P,
model: String,
messages: Vec<LlmMessage>,
temperature: Option<f32>,
max_tokens: Option<u32>,
}
impl<P: Provider> RigCompletionBuilder<P> {
fn new(provider: P, model: String) -> Self {
Self {
provider,
model,
messages: Vec::new(),
temperature: None,
max_tokens: None,
}
}
pub fn system(mut self, content: impl Into<String>) -> Self {
self.messages.push(LlmMessage::System {
content: content.into(),
name: None,
});
self
}
pub fn user(mut self, content: impl Into<String>) -> Self {
self.messages.push(LlmMessage::User {
content: content.into(),
name: None,
});
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn max_tokens(mut self, max: u32) -> Self {
self.max_tokens = Some(max);
self
}
pub async fn send(self) -> Result<RigCompletion, LlmError> {
let request = ChatCompletionRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
max_tokens: self.max_tokens,
stream: Some(false),
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
user: None,
tools: None,
tool_choice: None,
};
let response = self.provider.chat_completion(request).await?;
Ok(RigCompletion::from(response))
}
}
/// Simplified view of a chat completion result: the assistant text of the
/// first choice plus the model name and token usage echoed by the provider.
pub struct RigCompletion {
    // Assistant text of the first choice; empty string when the response
    // carried no assistant content (see the `From` impl below).
    pub content: String,
    // Model identifier reported back by the provider.
    pub model: String,
    // Token accounting, when the provider supplies it.
    pub usage: Option<crate::types::Usage>,
}
impl From<ChatCompletionResponse> for RigCompletion {
    /// Flattens a provider response: extracts the assistant text from the
    /// first choice (falling back to an empty string when there is no
    /// choice, the message is not an assistant turn, or its content is
    /// `None`) and carries the model name and usage over unchanged.
    fn from(response: ChatCompletionResponse) -> Self {
        let first = response.choices.first();
        let content = match first.map(|choice| &choice.message) {
            Some(LlmMessage::Assistant {
                content: Some(text),
                ..
            }) => text.clone(),
            _ => String::new(),
        };
        Self {
            content,
            model: response.model,
            usage: response.usage,
        }
    }
}
impl std::fmt::Display for RigCompletion {
    /// Displays just the completion text, so a `RigCompletion` can be
    /// printed or converted with `to_string()` directly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.content)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{LlmError, Provider};
    use crate::types::{
        ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse,
    };

    /// Provider stub whose chat endpoint always answers "Test response".
    #[derive(Clone, Debug)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        async fn chat_completion(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LlmError> {
            // Build the canned single-choice assistant reply.
            let reply = LlmMessage::Assistant {
                content: Some("Test response".to_string()),
                refusal: None,
                tool_calls: None,
            };
            let choice = crate::types::Choice {
                index: 0,
                message: reply,
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatCompletionResponse {
                id: "test".to_string(),
                object: "chat.completion".to_string(),
                created: 0,
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: None,
            })
        }

        async fn embeddings(
            &self,
            _request: EmbeddingRequest,
        ) -> Result<EmbeddingResponse, LlmError> {
            // Not exercised by these tests.
            unimplemented!()
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// End-to-end builder flow against the mock: system + user messages,
    /// a successful send, and the stubbed assistant text surfaced as
    /// `RigCompletion::content`.
    #[tokio::test]
    async fn test_rig_adapter() {
        let adapter = RigAdapter::new(MockProvider, "test-model");
        let result = adapter
            .completion()
            .system("Test system")
            .user("Test user")
            .send()
            .await;
        let completion = result.expect("mock provider never fails");
        assert_eq!(completion.content, "Test response");
    }
}