alith-client 0.4.3

The Easiest Rust Interface for Local LLMs, and an Interface for Deterministic Signals from Probabilistic LLM Vibes
Documentation
pub mod round;
pub mod step;

use alith_interface::requests::{
    completion::{CompletionFinishReason, CompletionRequest},
    stop_sequence::StoppingSequence,
};
use anyhow::{Result, anyhow};
use core::panic;
pub use round::CascadeRound;
use step::InferenceStep;

/// An ordered collection of [`CascadeRound`]s together with timing information.
///
/// Rounds are stored in the order they were added; the `Display` impl prints them
/// in that same order under the cascade's name.
#[derive(Clone)]
pub struct CascadeFlow {
    /// Name supplied to [`CascadeFlow::new`]; printed as the header by the `Display` impl.
    pub cascade_name: String,
    /// Time elapsed since `start_time`, as last recorded by `run_all_rounds` (on success)
    /// or `close_cascade`. Zero until one of those has stored a value.
    pub duration: std::time::Duration,
    /// Initialised to `false` by `new` and not read anywhere in this file.
    /// NOTE(review): presumably marks that a missing final result is acceptable —
    /// confirm against the callers that set or read it.
    pub result_can_be_none: bool,
    /// The rounds of this cascade, in insertion order.
    pub rounds: Vec<CascadeRound>,
    /// Reference instant for `duration`; set at construction and reset by
    /// `open_cascade` and `run_all_rounds`.
    pub start_time: std::time::Instant,
}

impl CascadeFlow {
    /// Builds an empty cascade called `cascade_name`.
    ///
    /// The clock starts at the moment of construction, the recorded duration is zero,
    /// no rounds exist yet and `result_can_be_none` is `false`.
    pub fn new<T: Into<String>>(cascade_name: T) -> Self {
        let cascade_name = cascade_name.into();
        let start_time = std::time::Instant::now();
        Self {
            cascade_name,
            duration: std::time::Duration::ZERO,
            result_can_be_none: false,
            rounds: Vec::new(),
            start_time,
        }
    }

    /// Creates a round for `task`, appends it to the cascade and hands back a mutable
    /// reference to it so the caller can keep configuring it.
    pub fn new_round<T: Into<String>>(&mut self, task: T) -> &mut CascadeRound {
        self.add_round(CascadeRound::new(task));
        self.rounds
            .last_mut()
            .expect("a round was pushed on the line above")
    }

    /// Appends an already-built `round` to the end of the cascade.
    pub fn add_round(&mut self, round: CascadeRound) {
        self.rounds.push(round);
    }

    /// Runs every round in insertion order against `base_req`.
    ///
    /// The start time is reset before the first round runs. When all rounds succeed
    /// the elapsed time is stored in `duration`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a round's `run_all_steps`; later rounds are
    /// not run and `duration` keeps whatever value it had before the call.
    pub async fn run_all_rounds(&mut self, base_req: &mut CompletionRequest) -> Result<()> {
        self.open_cascade();

        for round in &mut self.rounds {
            round.run_all_steps(base_req).await?;
        }

        self.close_cascade()
    }

    /// Mutable access to the most recently added round.
    ///
    /// # Errors
    ///
    /// Fails with "No rounds in cascade" when the cascade is empty.
    pub fn last_round(&mut self) -> Result<&mut CascadeRound> {
        self.rounds
            .last_mut()
            .ok_or_else(|| anyhow!("No rounds in cascade"))
    }

    /// Removes and discards the most recently added round.
    ///
    /// # Errors
    ///
    /// Fails with "No rounds in cascade" when there is nothing to remove.
    pub fn drop_last_round(&mut self) -> crate::Result<()> {
        if self.rounds.pop().is_none() {
            crate::bail!("No rounds in cascade");
        }
        Ok(())
    }

    /// Restarts the clock: `start_time` becomes "now".
    pub fn open_cascade(&mut self) {
        self.start_time = std::time::Instant::now();
    }

    /// Stores the time elapsed since `start_time` in `duration`.
    ///
    /// Always returns `Ok(())`.
    pub fn close_cascade(&mut self) -> Result<()> {
        let elapsed = self.start_time.elapsed();
        self.duration = elapsed;
        Ok(())
    }

    /// Returns whatever the last round's `primitive_result` returns.
    ///
    /// # Panics
    ///
    /// Panics with "No rounds in cascade" when the cascade has no rounds.
    pub fn primitive_result(&self) -> Option<String> {
        let Some(final_round) = self.rounds.last() else {
            panic!("No rounds in cascade");
        };
        final_round.primitive_result()
    }
}

/// Sends `base_req` and records the outcome on `step.llm_content`.
///
/// * If the completion finished by matching a `StoppingSequence::NoResult` stop
///   sequence, `step.llm_content` is set to `None`.
/// * Otherwise the response content is passed through the step grammar's
///   `validate_clean`; on success the cleaned content is stored in `step.llm_content`.
/// * If validation fails, the error is logged with `crate::info!` and
///   `step.llm_content` is left as it was; the function still returns `Ok(())`.
///
/// # Errors
///
/// Only a failure of the request itself is propagated.
pub(crate) async fn cascade_request(
    base_req: &mut CompletionRequest,
    step: &mut InferenceStep,
) -> Result<()> {
    let response = base_req.request().await?;

    let stopped_without_result = matches!(
        response.finish_reason,
        CompletionFinishReason::MatchingStoppingSequence(StoppingSequence::NoResult(_))
    );
    if stopped_without_result {
        step.llm_content = None;
        return Ok(());
    }

    let validated = step.step_config.grammar.validate_clean(&response.content);
    match validated {
        Err(e) => {
            // Validation failures are only logged; the caller sees `Ok(())`.
            crate::info!(?e);
        }
        Ok(content) => {
            // NOTE(review): if `validate_clean` already yields an owned `String`,
            // this clone is redundant — confirm against its signature.
            step.llm_content = Some(content.clone());
        }
    }
    Ok(())
}

/// Renders the cascade for a terminal: a blank line, the cascade name in bold with a
/// 24-bit ANSI foreground colour, another blank line, then each round under a bold
/// "Round N" heading. Heading colours are taken from `ROUND_GRADIENT`, starting over
/// from the first entry once the palette is exhausted.
impl std::fmt::Display for CascadeFlow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f)?;
        writeln!(f, "\x1b[1m\x1B[38;2;92;244;37m{}\x1b[0m", self.cascade_name)?;
        writeln!(f)?;

        // Pair each round with its 1-based number and an endlessly repeating palette.
        let numbered_rounds = (1usize..).zip(self.rounds.iter());
        let palette = ROUND_GRADIENT.iter().cycle();
        for ((number, round), color) in numbered_rounds.zip(palette) {
            writeln!(f, "\x1b[1m{color}Round {}\x1b[0m", number)?;
            writeln!(f, "{round}",)?;
        }
        Ok(())
    }
}
/// ANSI 24-bit foreground colour escape sequences used for the "Round N" headings
/// printed by the `Display` impl of `CascadeFlow`.
///
/// The entries are compile-time constants, so a plain `static` array is enough:
/// no lazy initialisation or heap allocation is needed, and `.len()`, indexing and
/// `.iter()` work exactly as they do on a `Vec`.
static ROUND_GRADIENT: [&str; 15] = [
    "\x1B[38;2;230;175;45m",
    "\x1B[38;2;235;158;57m",
    "\x1B[38;2;235;142;68m",
    "\x1B[38;2;232;127;80m",
    "\x1B[38;2;226;114;91m",
    "\x1B[38;2;216;103;100m",
    "\x1B[38;2;204;94;108m",
    "\x1B[38;2;189;88;114m",
    "\x1B[38;2;172;83;118m",
    "\x1B[38;2;153;79;119m",
    "\x1B[38;2;134;76;118m",
    "\x1B[38;2;115;73;114m",
    "\x1B[38;2;97;69;108m",
    "\x1B[38;2;80;65;99m",
    "\x1B[38;2;65;60;88m",
];