llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use std::sync::{atomic::AtomicU32, Arc};

use tokio::sync::{broadcast, RwLock};

use crate::{
    chat::{ChatMessage, ChatRole},
    memory::{MemoryProvider, MessageCondition, MessageEvent},
    LLMProvider,
};

/// Settings consumed by the reactive listener task in this module.
///
/// Built with [`ReactiveConfig::new`] plus the chained setter methods, then
/// handed to `spawn_reactive_listener`.
pub(super) struct ReactiveConfig {
    /// Shared memory store. Read to obtain the event receiver and to recall
    /// the conversation context; written to when a generated reply is saved.
    memory: Arc<RwLock<Box<dyn MemoryProvider>>>,
    /// Backend whose `chat` method is called to produce a reply.
    provider: Arc<dyn LLMProvider>,
    /// `(role, condition)` pairs. An event is handled when its resolved role
    /// equals `role` and `condition` matches the event.
    role_triggers: Vec<(String, MessageCondition)>,
    /// Role under which replies are stored. When `None`, generated replies
    /// are not saved to memory.
    role: Option<String>,
    /// Upper bound compared against `cycle_counter`; `None` means no bound.
    max_cycles: Option<u32>,
    /// Incremented after each reply that is saved successfully. Held in an
    /// `Arc`, and replaceable through the `cycle_counter` setter.
    cycle_counter: Arc<AtomicU32>,
}

impl ReactiveConfig {
    /// Creates a configuration bound to `memory` and `provider`.
    ///
    /// Every optional setting starts out empty: no role triggers, no reply
    /// role, no cycle limit, and a fresh cycle counter initialised to zero.
    pub(super) fn new(
        memory: Arc<RwLock<Box<dyn MemoryProvider>>>,
        provider: Arc<dyn LLMProvider>,
    ) -> Self {
        let cycle_counter = Arc::new(AtomicU32::default());
        Self {
            memory,
            provider,
            role_triggers: Vec::default(),
            role: None,
            max_cycles: None,
            cycle_counter,
        }
    }

    /// Replaces the list of `(role, condition)` triggers and returns the
    /// updated configuration.
    pub(super) fn role_triggers(self, triggers: Vec<(String, MessageCondition)>) -> Self {
        Self {
            role_triggers: triggers,
            ..self
        }
    }

    /// Sets the role under which replies are stored and returns the updated
    /// configuration.
    pub(super) fn role(self, role: Option<String>) -> Self {
        Self { role, ..self }
    }

    /// Sets the optional cap on the cycle counter and returns the updated
    /// configuration.
    pub(super) fn max_cycles(self, max_cycles: Option<u32>) -> Self {
        Self { max_cycles, ..self }
    }

    /// Swaps in a caller-provided cycle counter and returns the updated
    /// configuration.
    pub(super) fn cycle_counter(self, counter: Arc<AtomicU32>) -> Self {
        Self {
            cycle_counter: counter,
            ..self
        }
    }
}

/// Starts the reactive listener for `config` on the Tokio runtime.
///
/// The task's join handle is discarded, so the listener runs detached from
/// the caller.
pub(super) fn spawn_reactive_listener(config: ReactiveConfig) {
    // `run_reactive` already yields a future resolving to `()`, so it can be
    // handed to the runtime without a wrapping `async` block.
    tokio::spawn(run_reactive(config));
}

async fn run_reactive(config: ReactiveConfig) {
    let mut receiver = match get_receiver(&config).await {
        Some(receiver) => receiver,
        None => return,
    };

    while let Ok(event) = receiver.recv().await {
        if !should_handle(&config, &event) || reached_max(&config) {
            continue;
        }
        let context = load_context(&config).await;
        if let Some(text) = generate_reply(&config, &context).await {
            store_reply(&config, text).await;
        }
    }
}

/// Takes a read lock on the configured memory and asks it for a broadcast
/// receiver of message events.
///
/// Returns whatever the memory provider returns, which may be `None`.
async fn get_receiver(config: &ReactiveConfig) -> Option<broadcast::Receiver<MessageEvent>> {
    config.memory.read().await.get_event_receiver()
}

/// Reports whether `event` satisfies at least one configured role trigger.
///
/// A trigger fires when its role equals the event's resolved role and its
/// condition matches the event. The condition is evaluated only for triggers
/// whose role already matched.
fn should_handle(config: &ReactiveConfig, event: &MessageEvent) -> bool {
    let sender = resolve_event_role(event);
    for (trigger_role, condition) in &config.role_triggers {
        if trigger_role == sender && condition.matches(event) {
            return true;
        }
    }
    false
}

/// Returns the role name used when matching `event` against triggers.
///
/// Messages whose chat role is `ChatRole::User` always resolve to `"user"`;
/// any other message resolves to the role recorded on the event itself.
fn resolve_event_role(event: &MessageEvent) -> &str {
    let from_user = event.msg.role == ChatRole::User;
    if !from_user {
        return &event.role;
    }
    "user"
}

/// Reports whether the cycle counter has reached the configured limit.
///
/// Always `false` when no limit is set. The counter is read with relaxed
/// ordering.
fn reached_max(config: &ReactiveConfig) -> bool {
    use std::sync::atomic::Ordering;

    match config.max_cycles {
        Some(limit) => {
            let completed = config.cycle_counter.load(Ordering::Relaxed);
            completed >= limit
        }
        None => false,
    }
}

/// Recalls the conversation context from memory under a read lock.
///
/// The recall is issued with an empty query and no limit argument. If it
/// fails, the error is logged as a warning and an empty context is returned.
async fn load_context(config: &ReactiveConfig) -> Vec<ChatMessage> {
    let store = config.memory.read().await;
    let recalled = store.recall("", None).await;
    recalled.unwrap_or_else(|err| {
        log::warn!("Reactive memory recall error: {err}");
        Vec::new()
    })
}

/// Sends `context` to the provider's chat endpoint and extracts the reply
/// text.
///
/// Returns `None` when the chat call fails (the error is logged as a warning)
/// or when the response carries no text.
async fn generate_reply(config: &ReactiveConfig, context: &[ChatMessage]) -> Option<String> {
    let outcome = config.provider.chat(context).await;
    match outcome {
        Ok(response) => response.text(),
        Err(err) => {
            log::warn!("Reactive chat error: {err}");
            None
        }
    }
}

/// Saves a generated reply to memory under the configured role.
///
/// Does nothing when no role is configured. Otherwise the reply is stored as
/// an assistant message whose content is prefixed with `[role]`. The cycle
/// counter is bumped only after a successful save; a failed save is logged as
/// a warning and leaves the counter untouched.
async fn store_reply(config: &ReactiveConfig, text: String) {
    use std::sync::atomic::Ordering;

    let role = match &config.role {
        Some(role) => role.clone(),
        None => return,
    };

    let content = format!("[{role}] {text}");
    let msg = ChatMessage::assistant().content(content).build();

    // The write guard stays alive until the function returns, so the counter
    // update below happens while the memory lock is still held.
    let mut store = config.memory.write().await;
    match store.remember_with_role(&msg, role).await {
        Ok(_) => {
            config.cycle_counter.fetch_add(1, Ordering::Relaxed);
        }
        Err(err) => {
            log::warn!("Reactive memory save error: {err}");
        }
    }
}