1use async_trait::async_trait;
2use futures_util::Stream;
3use models::*;
4use reqwest::header::HeaderMap;
5use rmcp::model::Content;
6use stakpak_shared::models::integrations::openai::{
7 ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, Tool,
8};
9use uuid::Uuid;
10
11pub mod client;
12pub mod commands;
13pub mod error;
14pub mod local;
15pub mod models;
16pub mod stakpak;
17pub mod storage;
18
19pub use client::{
21 AgentClient, AgentClientConfig, DEFAULT_STAKPAK_ENDPOINT, ModelOptions, StakpakConfig,
22};
23
24pub use stakai::{Model, ModelCost, ModelLimit};
26
27pub use storage::{
29 BoxedSessionStorage, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
30 CreateSessionRequest as StorageCreateSessionRequest, CreateSessionResult, ListCheckpointsQuery,
31 ListCheckpointsResult, ListSessionsQuery, ListSessionsResult, LocalStorage, Session,
32 SessionStats, SessionStatus, SessionStorage, SessionSummary, SessionVisibility, StakpakStorage,
33 StorageError, UpdateSessionRequest as StorageUpdateSessionRequest,
34};
35
36pub fn find_model(model_str: &str, use_stakpak: bool) -> Option<Model> {
44 const PROVIDERS: &[&str] = &["anthropic", "openai", "google"];
45
46 let (provider_hint, model_id) = parse_model_string(model_str);
47
48 let model = provider_hint
50 .and_then(|p| find_in_provider(p, model_id))
51 .or_else(|| {
52 PROVIDERS
53 .iter()
54 .find_map(|&p| find_in_provider(p, model_id))
55 })?;
56
57 Some(if use_stakpak {
58 transform_for_stakpak(model)
59 } else {
60 model
61 })
62}
63
/// Split a model string into an optional provider hint and the model id.
///
/// `"anthropic/claude-3"` yields `(Some("anthropic"), "claude-3")`, while a
/// bare id such as `"gpt-4"` yields `(None, "gpt-4")`. Only the first `/`
/// splits, so ids that themselves contain `/` survive intact. The legacy
/// `"gemini"` provider name is normalized to `"google"`.
fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    // `split_once` replaces the manual `find` + slice pair, which also makes
    // the `clippy::string_slice` allow unnecessary.
    match s.split_once('/') {
        Some((provider, model_id)) => {
            // Accept "gemini/..." as an alias for the google provider.
            let normalized = if provider == "gemini" {
                "google"
            } else {
                provider
            };
            (Some(normalized), model_id)
        }
        None => (None, s),
    }
}
80
81fn find_in_provider(provider_id: &str, model_id: &str) -> Option<Model> {
83 let models = stakai::load_models_for_provider(provider_id).ok()?;
84
85 if let Some(model) = models.iter().find(|m| m.id == model_id) {
87 return Some(model.clone());
88 }
89
90 let mut best_match: Option<&Model> = None;
93 let mut best_len = 0;
94
95 for model in &models {
96 if model_id.starts_with(&model.id) && model.id.len() > best_len {
97 best_match = Some(model);
98 best_len = model.id.len();
99 }
100 }
101
102 best_match.cloned()
103}
104
105pub fn transform_for_stakpak(model: Model) -> Model {
110 Model {
111 id: format!("{}/{}", model.provider, model.id),
112 provider: "stakpak".into(),
113 name: model.name,
114 reasoning: model.reasoning,
115 cost: model.cost,
116 limit: model.limit,
117 release_date: model.release_date,
118 }
119}
120
/// Unified interface implemented by every agent backend.
///
/// Extends [`SessionStorage`], so an `AgentProvider` also handles session and
/// checkpoint persistence. All fallible methods surface errors as plain
/// `String`s rather than a typed error enum.
#[async_trait]
pub trait AgentProvider: SessionStorage + Send + Sync {
    /// Fetch the account belonging to the currently authenticated user.
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String>;
    /// Fetch billing details for the given account username.
    async fn get_billing_info(
        &self,
        account_username: &str,
    ) -> Result<stakpak_shared::models::billing::BillingResponse, String>;

    /// List the rulebooks visible to the current account.
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String>;
    /// Fetch a single rulebook addressed by its URI.
    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String>;
    /// Create a rulebook; `visibility` falls back to a provider-side default
    /// when `None` (default value not visible here — confirm per backend).
    async fn create_rulebook(
        &self,
        uri: &str,
        description: &str,
        content: &str,
        tags: Vec<String>,
        visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String>;
    /// Delete the rulebook addressed by `uri`.
    async fn delete_rulebook(&self, uri: &str) -> Result<(), String>;

    /// Run a non-streaming chat completion against `model`.
    ///
    /// `session_id` and `metadata` are optional correlation data; how each
    /// backend uses them is not visible from this trait.
    async fn chat_completion(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<ChatCompletionResponse, String>;
    /// Run a streaming chat completion.
    ///
    /// On success returns the boxed response stream plus an `Option<String>`
    /// — presumably the request id accepted by [`cancel_stream`], but that
    /// contract is not visible here; confirm against implementations.
    /// `headers` allows callers to forward extra HTTP headers.
    async fn chat_completion_stream(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<
        (
            std::pin::Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    >;
    /// Cancel an in-flight streaming completion identified by `request_id`.
    async fn cancel_stream(&self, request_id: String) -> Result<(), String>;

    /// Search documentation; results come back as MCP [`Content`] items.
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String>;

    /// Persist the session state at `checkpoint_id` into long-term memory.
    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String>;
    /// Query previously memorized session data.
    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String>;

    /// Read messages from a Slack channel/conversation.
    async fn slack_read_messages(
        &self,
        input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String>;
    /// Read the replies in a Slack thread.
    async fn slack_read_replies(
        &self,
        input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String>;
    /// Send a message to Slack on behalf of the agent.
    async fn slack_send_message(
        &self,
        input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String>;

    /// Enumerate the models this backend can serve (infallible by signature).
    async fn list_models(&self) -> Vec<Model>;
}