lemon_llm/replicate.rs

use std::collections::HashMap;

use replicate_rust::{config::Config, Replicate};

use crate::{GenerateError, LlmBackend};

/// LLM backend that runs prompts against a hosted model on Replicate.
pub struct ReplicateBackend {
    pub model: ReplicateModel,
    config: Config,
}

/// Replicate models supported by this backend, identified by their
/// pinned `owner/name:version` strings.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ReplicateModel {
    Llama2,
    #[default]
    Mistral7B,
}

impl ReplicateModel {
    /// Returns the full Replicate model identifier, including the version hash.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Llama2 => "meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
            Self::Mistral7B => "mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70",
        }
    }
}

impl ReplicateBackend {
    /// Creates a backend for the given model, using `config` for the Replicate API credentials.
    pub fn new(model: ReplicateModel, config: Config) -> Self {
        Self { model, config }
    }
}

impl LlmBackend for ReplicateBackend {
    async fn generate(&self, prompt: &str) -> Result<String, GenerateError> {
        let replicate = Replicate::new(self.config.clone());

        // Replicate predictions take their parameters as a key/value map;
        // the text models used here only need a `prompt` input.
        let mut inputs = HashMap::new();
        inputs.insert("prompt", prompt);

        // Run the prediction and surface any client error as a backend error.
        let result = replicate
            .run(self.model.as_str(), inputs)
            .map_err(|e| GenerateError::BackendError(e.to_string()))?;

        let output = result
            .output
            .ok_or_else(|| GenerateError::BackendError("No output".to_string()))?;

        // These models return their completion as an array of string chunks;
        // join the chunks into a single trimmed response.
        let array = output.as_array().ok_or_else(|| {
            GenerateError::BackendError("Output is not an array".to_string())
        })?;

        Ok(array
            .iter()
            .map(|x| x.as_str().unwrap_or_default())
            .collect::<String>()
            .trim()
            .to_string())
    }
}
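
A minimal sketch of how this backend might be driven from a consumer crate. The import paths, the Tokio runtime, `GenerateError` deriving `Debug`, and `Config::default()` reading the Replicate API token from the environment are all assumptions, not part of this file:

// Hypothetical consumer; paths and runtime are assumptions.
use lemon_llm::{LlmBackend, ReplicateBackend, ReplicateModel};
use replicate_rust::config::Config;

#[tokio::main]
async fn main() {
    // Assumes Config::default() picks up REPLICATE_API_TOKEN from the environment.
    let backend = ReplicateBackend::new(ReplicateModel::Mistral7B, Config::default());

    match backend.generate("Explain borrowing in Rust in one sentence.").await {
        Ok(text) => println!("{text}"),
        Err(err) => eprintln!("generation failed: {err:?}"),
    }
}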