1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use sllm::message::PromptMessageGroup;

use crate::{Error, Model, ModuleParam, UnitProcess};

/// A processing unit that hands its input to a wrapped [`Model`] and returns
/// the model's response.
///
/// Each unit carries a `name`, which is reported through
/// `UnitProcess::get_name` and prefixed to its debug log lines.
#[derive(Debug, Clone)]
pub struct ModelUnit {
    /// Identifier of this unit.
    name: String,
    /// Model that generates this unit's output.
    model: Model,
}

impl ModelUnit {
    /// Builds a unit called `name` that delegates generation to `model`.
    ///
    /// The borrowed `name` is copied into an owned `String`.
    pub fn new(name: &str, model: Model) -> Self {
        let name = name.to_owned();
        Self { model, name }
    }
}

#[async_trait::async_trait]
impl UnitProcess for ModelUnit {
    /// Returns the name this unit was constructed with.
    fn get_name(&self) -> &str {
        self.name.as_str()
    }

    /// Converts `input` into prompt message groups and asks the wrapped
    /// [`Model`] to generate a response for them.
    ///
    /// Input mapping:
    /// - `ModuleParam::Str(text)`: becomes a single group created with
    ///   `PromptMessageGroup::new("")`, into which `text` is inserted via
    ///   `insert("", ..)`.
    /// - `ModuleParam::MessageBuilders(groups)`: forwarded unchanged.
    /// - `ModuleParam::None`: an empty list of groups is sent to the model.
    ///
    /// The model's result is converted into a [`ModuleParam`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `Model::generate_response` produces, converted
    /// into this crate's [`Error`].
    async fn process(&self, input: ModuleParam) -> Result<ModuleParam, Error> {
        log::debug!("[{}] input - {:?}", self.name, input);

        // Normalise every accepted input shape into the group list the model expects.
        let groups = match input {
            ModuleParam::Str(req) => {
                let mut group = PromptMessageGroup::new("");
                group.insert("", req.as_str());
                vec![group]
            }
            ModuleParam::MessageBuilders(groups) => groups,
            // NOTE(review): a missing input is forwarded to the model as an
            // empty request rather than rejected; return
            // `Err(Error::InputRequiredError)` here if input should be mandatory.
            ModuleParam::None => vec![],
        };

        self.model
            .generate_response(groups)
            .await
            .map(Into::into)
            .map_err(Into::into)
    }
}