1use std::sync::Arc;
2
3use sllm::message::{MessageBuilder, PromptMessage};
4
5pub mod sync;
6pub mod units;
7
8mod error;
9mod pipeline_net;
10mod prompt_manager;
11mod traits;
12
13pub use error::Error;
14pub use pipeline_net::PipelineNet;
15pub use prompt_manager::PromptManager;
16pub use sllm::Backend;
17pub use traits::*;
18
/// Describes the shape of a type as a compact "keyword" string.
///
/// In this crate the implementation is generated by the `KeywordString`
/// derive from `ai_agent_macro`. The in-file (currently `#[ignore]`d) test
/// expects a nested struct to render like
/// `"{sub{prop1, prop2, prop3}, prop[{prop1, prop2, prop3}]}"`.
pub trait ToKeywordString {
    /// Returns the keyword description of `Self`.
    ///
    /// This is an associated function (no receiver), so call it on the type:
    /// `MyStruct::to_keyword_string()`.
    fn to_keyword_string() -> String;
}
22
23pub mod prelude {
24 pub use super::ToKeywordString;
25 pub use ai_agent_macro::*;
26 pub use sllm::message::{MessageBuilder, PromptMessage, TemplatedMessage};
27}
28
29#[derive(Debug, Clone)]
30pub enum ModuleParam {
31 Str(String),
32 MessageBuilders(Vec<PromptMessage>),
33 None,
34}
35
36impl ModuleParam {
37 pub fn is_none(&self) -> bool {
38 match self {
39 Self::None => true,
40 _ => false,
41 }
42 }
43
44 pub fn into_message_group(self) -> Option<Vec<PromptMessage>> {
45 match self {
46 Self::MessageBuilders(group) => Some(group),
47 _ => None,
48 }
49 }
50
51 pub fn into_string(self) -> Option<String> {
52 match self {
53 Self::Str(s) => Some(s),
54 _ => None,
55 }
56 }
57
58 pub fn as_message_group(&self) -> Option<&Vec<PromptMessage>> {
59 match self {
60 Self::MessageBuilders(group) => Some(group),
61 _ => None,
62 }
63 }
64
65 pub fn as_string(&self) -> Option<&String> {
66 match self {
67 Self::Str(s) => Some(s),
68 _ => None,
69 }
70 }
71}
72
73impl Default for ModuleParam {
74 fn default() -> Self {
75 Self::None
76 }
77}
78
79impl From<&str> for ModuleParam {
80 fn from(val: &str) -> Self {
81 ModuleParam::Str(val.into())
82 }
83}
84
85impl From<Vec<PromptMessage>> for ModuleParam {
86 fn from(val: Vec<PromptMessage>) -> Self {
87 ModuleParam::MessageBuilders(val)
88 }
89}
90
91impl From<String> for ModuleParam {
92 fn from(val: String) -> Self {
93 ModuleParam::Str(val)
94 }
95}
96
97#[derive(Debug, Clone)]
100pub struct Model {
101 model: Arc<sync::Mutex<sllm::Model>>,
102}
103
104impl Model {
105 pub fn new(backend: Backend) -> Result<Self, Error> {
106 let model = sllm::Model::new(backend)?;
107 Ok(Self {
108 model: Arc::new(sync::Mutex::new(model)),
109 })
110 }
111
112 pub async fn set_temperature(&self, temperature: f64) {
113 let mut model = self.model.lock().await;
114 model.set_temperature(temperature);
115 }
116
117 pub async fn generate_response<T>(&self, input: T) -> Result<String, Error>
118 where
119 T: IntoIterator + Send,
120 T::Item: MessageBuilder + Send,
121 {
122 let model = self.model.lock().await;
123 let result = model.generate_response(input).await?;
124 Ok(result)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::{Model, ToKeywordString};
    use ai_agent_macro::KeywordString;

    /// Builds a ChatGPT-backed [`Model`] (`gpt-3.5-turbo`) for tests that
    /// talk to the real backend.
    ///
    /// Loads a `.env` file if one exists, then reads the API key from
    /// `OPEN_API_KEY`. NOTE(review): kept `pub`, presumably so that tests in
    /// other modules can reuse it — confirm before narrowing visibility.
    ///
    /// # Panics
    /// Panics with a descriptive message if `OPEN_API_KEY` is not set or if
    /// the model cannot be constructed.
    pub fn get_model() -> Model {
        // A missing `.env` file is not an error: the variable may already be
        // exported in the environment.
        dotenv::dotenv().ok();
        let api_key = std::env::var("OPEN_API_KEY")
            .expect("OPEN_API_KEY must be set (env or .env file) to run model tests");
        Model::new(sllm::Backend::ChatGPT {
            api_key,
            model: "gpt-3.5-turbo".into(),
        })
        .expect("failed to construct the ChatGPT-backed test model")
    }

    // The fields are never read at runtime. They exist only so that the
    // `KeywordString` derive has a shape to describe, hence `dead_code`.
    #[allow(dead_code)]
    #[derive(KeywordString)]
    struct SubStruct {
        prop1: i32,
        prop2: f32,
        prop3: String,
    }

    // Nested and `Vec` fields exercise the derive's recursive output.
    #[allow(dead_code)]
    #[derive(KeywordString)]
    struct TestStruct {
        sub: SubStruct,
        prop: Vec<SubStruct>,
    }

    // NOTE(review): this test is ignored; the reason is not visible here.
    // Confirm whether the derive output still matches this expectation.
    #[ignore]
    #[test]
    fn test_print_keyword() {
        assert_eq!(
            TestStruct::to_keyword_string(),
            "{sub{prop1, prop2, prop3}, prop[{prop1, prop2, prop3}]}"
        );
    }
}