//! shai — crate root (`lib.rs`).
#![allow(clippy::future_not_send)]

pub(crate) mod anthropic;
pub mod cli;
mod context;
mod model;
mod openai;
mod prompts;
pub(crate) mod sse_parser;

use anthropic::AnthropicModel;
use context::Context;
use futures::Stream;
use model::Task;
use openai::OpenAIGPTModel;
use serde::Deserialize;
use thiserror::Error;
/// Selects which task configuration is in play: command generation
/// (`Ask`) or command explanation (`Explain`).
enum ConfigKind {
    Ask(AskConfig),
    Explain(ExplainConfig),
}

24impl ConfigKind {
25    const fn model(&self) -> &ModelKind {
26        match self {
27            Self::Ask(config) => &config.model,
28            Self::Explain(config) => &config.model,
29        }
30    }
31}
32
/// Configuration for the "ask" task (defaults suggest targeting a
/// Linux/Bash environment — see `Default` impl below).
#[derive(Deserialize)]
struct AskConfig {
    // Operating system name (Default impl uses "Linux").
    operating_system: String,
    // Shell name (Default impl uses "Bash").
    shell: String,
    // Optional extra environment details to include in the context.
    environment: Option<Vec<String>>,
    // Optional list of available programs to include in the context.
    programs: Option<Vec<String>>,
    // NOTE(review): `Option<()>` can only encode presence/absence, never
    // an actual directory — presumably meant to be `Option<PathBuf>` or
    // `Option<String>`; confirm against `context.rs` usage.
    cwd: Option<()>,
    depth: Option<u32>,
    // Which backing LLM to query.
    model: ModelKind,
}

44#[derive(Deserialize)]
45struct ExplainConfig {
46    operating_system: String,
47    shell: String,
48    environment: Option<Vec<String>>,
49    model: ModelKind,
50    cwd: Option<()>,
51    depth: Option<u32>,
52}
53
54impl Default for AskConfig {
55    fn default() -> Self {
56        Self {
57            operating_system: "Linux".to_string(),
58            shell: "Bash".to_string(),
59            environment: None,
60            programs: None,
61            cwd: None,
62            depth: None,
63            model: ModelKind::OpenAIGPT(OpenAIGPTModel::GPT4oMini),
64        }
65    }
66}
67
68impl Default for ExplainConfig {
69    fn default() -> Self {
70        Self {
71            operating_system: "Linux".to_string(),
72            shell: "Bash".to_string(),
73            environment: None,
74            cwd: None,
75            depth: None,
76            model: ModelKind::OpenAIGPT(OpenAIGPTModel::GPT4oMini),
77        }
78    }
79}
80
/// The LLM backends this crate can dispatch requests to.
#[derive(Deserialize, Clone)]
enum ModelKind {
    OpenAIGPT(OpenAIGPTModel),
    Anthropic(AnthropicModel),
    // OpenAssistant // waiting for a minimal API, go guys :D
    // Local // ?
}

/// Unified error for model calls: backend failures are flattened into
/// their string representation (via `thiserror`'s `#[error]` display).
#[derive(Debug, Error)]
pub(crate) enum ModelError {
    #[error("{0}")]
    Error(String),
}

95impl From<Box<dyn std::error::Error + Send>> for ModelError {
96    fn from(e: Box<dyn std::error::Error + Send>) -> Self {
97        Self::Error(e.to_string())
98    }
99}
100
101#[allow(unused)]
102async fn model_request(
103    model: ModelKind,
104    request: String,
105    context: Context,
106    task: Task,
107) -> Result<String, ModelError> {
108    match model {
109        ModelKind::OpenAIGPT(model) => model
110            .send(request, context, task)
111            .await
112            .map_err(|err| ModelError::Error(err.to_string())),
113        ModelKind::Anthropic(model) => model
114            .send(request, context, task)
115            .await
116            .map_err(|err| ModelError::Error(err.to_string())),
117    }
118}
119
120async fn model_stream_request(
121    model: ModelKind,
122    request: String,
123    context: Context,
124    task: Task,
125) -> Result<impl Stream<Item = Result<String, ModelError>> + Send, ModelError> {
126    match model {
127        ModelKind::OpenAIGPT(model) => model
128            .send_streaming(request, context, task)
129            .await
130            .map_err(|e| ModelError::Error(e.to_string())),
131        ModelKind::Anthropic(model) => model
132            .send_streaming(request, context, task)
133            .await
134            .map_err(|e| ModelError::Error(e.to_string())),
135    }
136}
137
138fn build_context_request(request: &str, context: Context) -> String {
139    String::from(context) + &format!("Here is your <task>: \n <task>{request}</task>")
140}
141
// #[cfg(test)]
// mod tests {
//     use crate::{
//         context::Context, model::Task, model_stream_request, openai::OpenAIGPTModel::GPT35Turbo,
//         AskConfig, ConfigKind, ModelKind::OpenAIGPT,
//     };
//     use futures_util::StreamExt;
//
//     #[tokio::test]
//     async fn ssh_tunnel() {
//         let mut response_stream = model_stream_request(
//             OpenAIGPT(GPT35Turbo),
//             "make an ssh tunnel between port 8080 in this machine and port 1243 in the machine with IP 192.168.0.42".to_string(),
//             Context::from(ConfigKind::Ask(AskConfig::default())),
//             Task::GenerateCommand,
//         ).await.unwrap();
//         while response_stream.next().await.is_some() {}
//     }
// }
160// }