llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc, Mutex,
};

use async_trait::async_trait;

use crate::{
    chat::{ChatMessage, ChatProvider, ChatRole, Tool},
    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
    embedding::EmbeddingProvider,
    error::LLMError,
    memory::{ChatWithMemory, ChatWithMemoryConfig, MemoryProvider, SlidingWindowMemory},
    models::ModelsProvider,
    stt::SpeechToTextProvider,
    tts::TextToSpeechProvider,
    LLMProvider,
};

const RESPONSE_TEXT: &str = "ok";
const TRANSCRIPT: &str = "transcribed";

#[derive(Clone)]
struct RecordingProvider {
    messages: Arc<Mutex<Vec<ChatMessage>>>,
    transcript: String,
    response: String,
    transcribe_calls: Arc<AtomicUsize>,
}

impl RecordingProvider {
    fn new(
        messages: Arc<Mutex<Vec<ChatMessage>>>,
        transcribe_calls: Arc<AtomicUsize>,
        transcript: &str,
        response: &str,
    ) -> Self {
        Self {
            messages,
            transcript: transcript.to_string(),
            response: response.to_string(),
            transcribe_calls,
        }
    }
}

fn unsupported() -> LLMError {
    LLMError::ProviderError("unsupported".to_string())
}

#[async_trait]
impl ChatProvider for RecordingProvider {
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
    ) -> Result<Box<dyn crate::chat::ChatResponse>, LLMError> {
        let mut guard = self.messages.lock().expect("messages lock");
        guard.clear();
        guard.extend_from_slice(messages);
        Ok(Box::new(CompletionResponse {
            text: self.response.clone(),
        }))
    }
}

#[async_trait]
impl CompletionProvider for RecordingProvider {
    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        Err(unsupported())
    }
}

#[async_trait]
impl EmbeddingProvider for RecordingProvider {
    async fn embed(&self, _input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        Err(unsupported())
    }
}

#[async_trait]
impl SpeechToTextProvider for RecordingProvider {
    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
        self.transcribe_calls.fetch_add(1, Ordering::Relaxed);
        Ok(self.transcript.clone())
    }
}

#[async_trait]
impl TextToSpeechProvider for RecordingProvider {}

#[async_trait]
impl ModelsProvider for RecordingProvider {}

impl LLMProvider for RecordingProvider {}

#[tokio::test]
async fn audio_messages_are_transcribed_before_chat() {
    let messages = Arc::new(Mutex::new(Vec::new()));
    let transcribe_calls = Arc::new(AtomicUsize::new(0));
    let provider = RecordingProvider::new(
        Arc::clone(&messages),
        Arc::clone(&transcribe_calls),
        TRANSCRIPT,
        RESPONSE_TEXT,
    );
    let memory: Arc<tokio::sync::RwLock<Box<dyn MemoryProvider>>> = Arc::new(
        tokio::sync::RwLock::new(Box::new(SlidingWindowMemory::new(5))),
    );
    let config = ChatWithMemoryConfig::new(Arc::new(provider), Arc::clone(&memory));
    let wrapper = ChatWithMemory::with_config(config);

    let msg = ChatMessage::user().audio(vec![1, 2, 3]).build();
    let response = wrapper.chat(&[msg]).await.expect("chat");
    assert_eq!(response.text(), Some(RESPONSE_TEXT.to_string()));
    assert_eq!(transcribe_calls.load(Ordering::Relaxed), 1);

    let recorded = messages.lock().expect("messages lock").clone();
    let recorded_msg = recorded
        .iter()
        .find(|m| matches!(m.role, ChatRole::User))
        .expect("recorded user message");
    assert_eq!(recorded_msg.content, TRANSCRIPT);
    assert!(!recorded_msg.has_audio());

    let stored = memory.read().await.recall("", None).await.expect("recall");
    let stored_msg = stored
        .iter()
        .find(|m| matches!(m.role, ChatRole::User))
        .expect("stored user message");
    assert_eq!(stored_msg.content, TRANSCRIPT);
    assert!(!stored_msg.has_audio());
}