use crate::error::{Error, Result};
use std::future::Future;
use std::sync::Arc;
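
/// Output of a single generation call: the produced text plus token-usage
/// counters (left at zero when the backend does not report usage).
///
/// The text is stored as an `Arc<str>` so outputs can be cloned cheaply.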
#[derive(Debug, Clone)]
pub struct LmOutput {
pub text: Arc<str>,
pub prompt_tokens: u32,
pub completion_tokens: u32,
}
impl LmOutput {
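    /// Creates an output with the given text and zeroed token counters.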
pub fn new(text: impl Into<Arc<str>>) -> Self {
Self {
text: text.into(),
prompt_tokens: 0,
completion_tokens: 0,
}
}
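    /// Creates an output with explicit prompt and completion token counts.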
pub fn with_tokens(text: impl Into<Arc<str>>, prompt: u32, completion: u32) -> Self {
Self {
text: text.into(),
prompt_tokens: prompt,
completion_tokens: completion,
}
}
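    /// Total tokens used (prompt + completion), saturating on overflow.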
pub fn total_tokens(&self) -> u32 {
        self.prompt_tokens.saturating_add(self.completion_tokens)
}
}
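
/// A language-model backend.
///
/// The returned future is a generic associated type, so implementations can
/// avoid boxing (`Box<dyn Future>`) entirely; the trade-off is that the
/// trait is not object-safe. A minimal usage sketch with the `MockLlm` test
/// double defined below (illustrative, not compiled as a doc test):
///
/// ```ignore
/// let llm = MockLlm::new(|prompt, _feedback| format!("echo: {prompt}"));
/// let out = llm.generate("hi", "", None).await?;
/// assert_eq!(&*out.text, "echo: hi");
/// ```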
pub trait Llm: Send + Sync {
type GenerateFut<'a>: Future<Output = Result<LmOutput>> + Send + 'a
where
Self: 'a;
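    /// Generates a completion for `prompt`, given optional surrounding
    /// `context` and optional `feedback` from a previous attempt.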
fn generate<'a>(
&'a self,
prompt: &'a str,
context: &'a str,
feedback: Option<&'a str>,
) -> Self::GenerateFut<'a>;
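    /// Human-readable identifier of the underlying model.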
fn model_name(&self) -> &str {
"unknown"
}
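    /// Maximum context window in tokens; the default is a conservative 4096.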
fn max_context(&self) -> usize {
4096
}
}
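
/// Test double that synthesizes a response from a closure over
/// `(prompt, feedback)`; the context argument is ignored.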
pub struct MockLlm<F>
where
F: Fn(&str, Option<&str>) -> String + Send + Sync,
{
generator: F,
name: &'static str,
}
impl<F> MockLlm<F>
where
F: Fn(&str, Option<&str>) -> String + Send + Sync,
{
pub fn new(generator: F) -> Self {
Self {
generator,
name: "mock",
}
}
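    /// Overrides the reported model name (default: "mock").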
pub fn with_name(mut self, name: &'static str) -> Self {
self.name = name;
self
}
}
impl<F> Llm for MockLlm<F>
where
F: Fn(&str, Option<&str>) -> String + Send + Sync,
{
type GenerateFut<'a>
= std::future::Ready<Result<LmOutput>>
where
Self: 'a;
fn generate<'a>(
&'a self,
prompt: &'a str,
_context: &'a str,
feedback: Option<&'a str>,
) -> Self::GenerateFut<'a> {
let text = (self.generator)(prompt, feedback);
std::future::ready(Ok(LmOutput::new(text)))
}
fn model_name(&self) -> &str {
self.name
}
}
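
/// Test double whose closure also receives a 0-based call counter, making it
/// easy to script different responses across retry iterations.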
pub struct IterativeMockLlm<F>
where
F: Fn(u32, &str, Option<&str>) -> String + Send + Sync,
{
generator: F,
iteration: std::sync::atomic::AtomicU32,
name: &'static str,
}
impl<F> IterativeMockLlm<F>
where
F: Fn(u32, &str, Option<&str>) -> String + Send + Sync,
{
pub fn new(generator: F) -> Self {
Self {
generator,
iteration: std::sync::atomic::AtomicU32::new(0),
name: "iterative_mock",
}
}
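    /// Overrides the reported model name (default: "iterative_mock").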
pub fn with_name(mut self, name: &'static str) -> Self {
self.name = name;
self
}
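    /// Resets the call counter to zero.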
pub fn reset(&self) {
self.iteration.store(0, std::sync::atomic::Ordering::SeqCst);
}
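    /// Number of `generate` calls made since creation or the last reset.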
pub fn current_iteration(&self) -> u32 {
self.iteration.load(std::sync::atomic::Ordering::SeqCst)
}
}
impl<F> Llm for IterativeMockLlm<F>
where
F: Fn(u32, &str, Option<&str>) -> String + Send + Sync,
{
type GenerateFut<'a>
= std::future::Ready<Result<LmOutput>>
where
Self: 'a;
fn generate<'a>(
&'a self,
prompt: &'a str,
_context: &'a str,
feedback: Option<&'a str>,
) -> Self::GenerateFut<'a> {
let iteration = self
.iteration
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let text = (self.generator)(iteration, prompt, feedback);
std::future::ready(Ok(LmOutput::new(text)))
}
fn model_name(&self) -> &str {
self.name
}
}
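
/// Test double whose `generate` always fails with the configured message;
/// useful for exercising error-handling paths.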
#[derive(Debug, Clone)]
pub struct FailingLlm {
message: String,
}
impl FailingLlm {
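    /// Creates a backend that fails every call with `message`.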
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
}
impl Llm for FailingLlm {
type GenerateFut<'a>
= std::future::Ready<Result<LmOutput>>
where
Self: 'a;
fn generate<'a>(
&'a self,
_prompt: &'a str,
_context: &'a str,
_feedback: Option<&'a str>,
) -> Self::GenerateFut<'a> {
std::future::ready(Err(Error::module(&self.message)))
}
fn model_name(&self) -> &str {
"failing"
}
}
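
/// Locates the `claude` binary, first by shelling out to the Unix `which`
/// command and then by probing common install locations.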
fn which_claude() -> Option<String> {
if let Ok(output) = std::process::Command::new("which").arg("claude").output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Some(path);
}
}
}
let common_paths = ["/usr/local/bin/claude", "/opt/homebrew/bin/claude"];
for path in &common_paths {
if std::path::Path::new(path).exists() {
return Some(path.to_string());
}
}
None
}
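
/// Backend that shells out to the `claude` CLI in print mode
/// (`claude -p <prompt> --output-format text`).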
pub struct CliLlm {
path: String,
}
impl CliLlm {
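    /// Discovers the `claude` binary, failing if it cannot be found.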
pub fn new() -> Result<Self> {
let path = which_claude()
.ok_or_else(|| Error::module("claude CLI not found. Install the CLI tool first."))?;
Ok(Self { path })
}
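    /// Uses an explicit binary path, skipping discovery.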
pub fn with_path(path: impl Into<String>) -> Self {
Self { path: path.into() }
}
}
impl Llm for CliLlm {
    type GenerateFut<'a>
        = std::future::Ready<Result<LmOutput>>
    where
        Self: 'a;
fn generate<'a>(
&'a self,
prompt: &'a str,
context: &'a str,
feedback: Option<&'a str>,
) -> Self::GenerateFut<'a> {
let mut combined = String::new();
if !context.is_empty() {
combined.push_str(context);
combined.push_str("\n\n");
}
combined.push_str(prompt);
if let Some(fb) = feedback {
combined.push_str("\n\n[Previous attempt feedback: ");
combined.push_str(fb);
combined.push(']');
}
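        // The subprocess call below runs synchronously and blocks the calling
        // thread; callers on an async runtime may prefer to move `generate`
        // onto a blocking pool (e.g. tokio's spawn_blocking).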
let result = std::process::Command::new(&self.path)
.args(["-p", &combined, "--output-format", "text"])
.output();
let output = match result {
Ok(o) => o,
Err(e) => {
return std::future::ready(Err(Error::module(format!(
"Failed to execute claude CLI: {}",
e
))));
}
};
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return std::future::ready(Err(Error::module(format!(
"claude CLI failed: {}",
stderr
))));
}
let text = match String::from_utf8(output.stdout) {
Ok(s) => s.trim().to_string(),
Err(e) => {
return std::future::ready(Err(Error::module(format!(
"claude CLI output is not valid UTF-8: {}",
e
))));
}
};
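        // The plain-text output carries no token counts, so approximate them
        // with a rough ~1.3 tokens-per-word heuristic.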
let word_count = text.split_whitespace().count() as u32;
let prompt_word_count = combined.split_whitespace().count() as u32;
let est_prompt_tokens = (prompt_word_count as f64 * 1.3) as u32;
let est_completion_tokens = (word_count as f64 * 1.3) as u32;
std::future::ready(Ok(LmOutput::with_tokens(
text,
est_prompt_tokens,
est_completion_tokens,
)))
}
fn model_name(&self) -> &str {
"claude-code"
}
fn max_context(&self) -> usize {
200_000
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_mock_llm() {
let llm = MockLlm::new(|prompt, _| format!("Response: {}", prompt));
let output = llm.generate("test prompt", "", None).await.unwrap();
assert_eq!(&*output.text, "Response: test prompt");
}
#[tokio::test]
async fn test_mock_llm_with_feedback() {
let llm = MockLlm::new(|prompt, feedback| match feedback {
Some(fb) => format!("Improved: {} (feedback: {})", prompt, fb),
None => format!("Initial: {}", prompt),
});
let output = llm.generate("test", "", None).await.unwrap();
assert!(output.text.starts_with("Initial:"));
let output = llm.generate("test", "", Some("do better")).await.unwrap();
assert!(output.text.starts_with("Improved:"));
assert!(output.text.contains("do better"));
}
#[tokio::test]
async fn test_iterative_mock_llm() {
let llm = IterativeMockLlm::new(|iter, _prompt, _| match iter {
0 => "first try".to_string(),
1 => "second try".to_string(),
_ => "final answer".to_string(),
});
let out1 = llm.generate("test", "", None).await.unwrap();
assert_eq!(&*out1.text, "first try");
let out2 = llm.generate("test", "", Some("improve")).await.unwrap();
assert_eq!(&*out2.text, "second try");
let out3 = llm.generate("test", "", Some("more")).await.unwrap();
assert_eq!(&*out3.text, "final answer");
}
#[tokio::test]
async fn test_failing_llm() {
let llm = FailingLlm::new("intentional failure");
let result = llm.generate("test", "", None).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("intentional failure"));
}
#[test]
fn test_lm_output() {
let output = LmOutput::new("test");
assert_eq!(&*output.text, "test");
assert_eq!(output.total_tokens(), 0);
let output = LmOutput::with_tokens("test", 10, 20);
assert_eq!(output.total_tokens(), 30);
}
}