1use std::collections::HashMap;
2
3use replicate_rust::{config::Config, Replicate};
4
5use crate::{GenerateError, LlmBackend};
6
7pub struct ReplicateBackend {
8 pub model: ReplicateModel,
9 config: Config,
10}
11
/// Models hosted on Replicate that [`ReplicateBackend`] can run.
///
/// Each variant maps to a pinned model identifier string, available through
/// [`ReplicateModel::as_str`]. The default model is [`ReplicateModel::Mistral7B`].
///
/// `Hash` is derived alongside `Eq` so a model can be used as a key in hashed
/// collections (for example, a per-model cache).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReplicateModel {
    /// The `meta/llama-2-7b` model.
    Llama2,
    /// The `mistralai/mistral-7b-instruct-v0.1` model.
    #[default]
    Mistral7B,
}
18
19impl ReplicateModel {
20 pub fn as_str(&self) -> &str {
21 match self {
22 Self::Llama2 => "meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
23 Self::Mistral7B => "mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70",
24 }
25 }
26}
27
28impl ReplicateBackend {
29 pub fn new(model: ReplicateModel, config: Config) -> Self {
30 Self { model, config }
31 }
32}
33
34impl LlmBackend for ReplicateBackend {
35 async fn generate(&self, prompt: &str) -> Result<String, GenerateError> {
36 let replicate = Replicate::new(self.config.clone());
37
38 let mut inputs = HashMap::new();
39 inputs.insert("prompt", prompt);
40
41 let result = replicate
42 .run(self.model.as_str(), inputs)
43 .map_err(|e| GenerateError::BackendError(e.to_string()))?;
44
45 let output = result
46 .output
47 .ok_or(GenerateError::BackendError("No output".to_string()))?;
48
49 let array = output.as_array().ok_or(GenerateError::BackendError(
50 "Output is not an array".to_string(),
51 ))?;
52
53 Ok(array
54 .iter()
55 .map(|x| x.as_str().unwrap_or_default())
56 .collect::<String>()
57 .trim()
58 .to_string())
59 }
60}