Skip to main content

stakpak_api/
models.rs

1use std::collections::HashMap;
2
3use chrono::{DateTime, Utc};
4use rmcp::model::Content;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use stakai::Model;
8use stakpak_shared::models::{
9    integrations::openai::{ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall},
10    llm::{LLMInput, LLMMessage, LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage},
11};
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Deserialize, Serialize)]
15pub enum ApiStreamError {
16    AgentInputInvalid(String),
17    AgentStateInvalid,
18    AgentNotSupported,
19    AgentExecutionLimitExceeded,
20    AgentInvalidResponseStream,
21    InvalidGeneratedCode,
22    CopilotError,
23    SaveError,
24    Unknown(String),
25}
26
27impl From<&str> for ApiStreamError {
28    fn from(error_str: &str) -> Self {
29        match error_str {
30            s if s.contains("Agent not supported") => ApiStreamError::AgentNotSupported,
31            s if s.contains("Agent state is not valid") => ApiStreamError::AgentStateInvalid,
32            s if s.contains("Agent thinking limit exceeded") => {
33                ApiStreamError::AgentExecutionLimitExceeded
34            }
35            s if s.contains("Invalid response stream") => {
36                ApiStreamError::AgentInvalidResponseStream
37            }
38            s if s.contains("Invalid generated code") => ApiStreamError::InvalidGeneratedCode,
39            s if s.contains(
40                "Our copilot is handling too many requests at this time, please try again later.",
41            ) =>
42            {
43                ApiStreamError::CopilotError
44            }
45            s if s
46                .contains("An error occurred while saving your data. Please try again later.") =>
47            {
48                ApiStreamError::SaveError
49            }
50            s if s.contains("Agent input is not valid: ") => {
51                ApiStreamError::AgentInputInvalid(s.replace("Agent input is not valid: ", ""))
52            }
53            _ => ApiStreamError::Unknown(error_str.to_string()),
54        }
55    }
56}
57
58impl From<String> for ApiStreamError {
59    fn from(error_str: String) -> Self {
60        ApiStreamError::from(error_str.as_str())
61    }
62}
63
64#[derive(Deserialize, Serialize, Debug)]
65pub struct Document {
66    pub content: String,
67    pub uri: String,
68    pub provisioner: ProvisionerType,
69}
70
71#[derive(Deserialize, Serialize, Debug)]
72pub struct SimpleDocument {
73    pub uri: String,
74    pub content: String,
75}
76
77#[derive(Deserialize, Serialize, Debug, Clone)]
78pub struct Block {
79    pub id: Uuid,
80    pub provider: String,
81    pub provisioner: ProvisionerType,
82    pub language: String,
83    pub key: String,
84    pub digest: u64,
85    pub references: Vec<Vec<Segment>>,
86    pub kind: String,
87    pub r#type: Option<String>,
88    pub name: Option<String>,
89    pub config: serde_json::Value,
90    pub document_uri: String,
91    pub code: String,
92    pub start_byte: usize,
93    pub end_byte: usize,
94    pub start_point: Point,
95    pub end_point: Point,
96    pub state: Option<serde_json::Value>,
97    pub updated_at: Option<DateTime<Utc>>,
98    pub created_at: Option<DateTime<Utc>>,
99    pub dependents: Vec<DependentBlock>,
100    pub dependencies: Vec<Dependency>,
101    pub api_group_version: Option<ApiGroupVersion>,
102
103    pub generated_summary: Option<String>,
104}
105
106impl Block {
107    pub fn get_uri(&self) -> String {
108        format!(
109            "{}#L{}-L{}",
110            self.document_uri, self.start_point.row, self.end_point.row
111        )
112    }
113}
114
115#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
116pub enum ProvisionerType {
117    #[serde(rename = "Terraform")]
118    Terraform,
119    #[serde(rename = "Kubernetes")]
120    Kubernetes,
121    #[serde(rename = "Dockerfile")]
122    Dockerfile,
123    #[serde(rename = "GithubActions")]
124    GithubActions,
125    #[serde(rename = "None")]
126    None,
127}
128impl std::str::FromStr for ProvisionerType {
129    type Err = String;
130
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        match s.to_lowercase().as_str() {
133            "terraform" => Ok(Self::Terraform),
134            "kubernetes" => Ok(Self::Kubernetes),
135            "dockerfile" => Ok(Self::Dockerfile),
136            "github-actions" => Ok(Self::GithubActions),
137            _ => Ok(Self::None),
138        }
139    }
140}
141impl std::fmt::Display for ProvisionerType {
142    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
143        match self {
144            ProvisionerType::Terraform => write!(f, "terraform"),
145            ProvisionerType::Kubernetes => write!(f, "kubernetes"),
146            ProvisionerType::Dockerfile => write!(f, "dockerfile"),
147            ProvisionerType::GithubActions => write!(f, "github-actions"),
148            ProvisionerType::None => write!(f, "none"),
149        }
150    }
151}
152
153#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
154#[serde(untagged)]
155pub enum Segment {
156    Key(String),
157    Index(usize),
158}
159
160impl std::fmt::Display for Segment {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            Segment::Key(key) => write!(f, "{}", key),
164            Segment::Index(index) => write!(f, "{}", index),
165        }
166    }
167}
168impl std::fmt::Debug for Segment {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        match self {
171            Segment::Key(key) => write!(f, "{}", key),
172            Segment::Index(index) => write!(f, "{}", index),
173        }
174    }
175}
176
177#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)]
178pub struct Point {
179    pub row: usize,
180    pub column: usize,
181}
182
183#[derive(Deserialize, Serialize, Debug, Clone)]
184pub struct DependentBlock {
185    pub key: String,
186}
187
188#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
189pub struct Dependency {
190    pub id: Option<Uuid>,
191    pub expression: Option<String>,
192    pub from_path: Option<Vec<Segment>>,
193    pub to_path: Option<Vec<Segment>>,
194    #[serde(default = "Vec::new")]
195    pub selectors: Vec<DependencySelector>,
196    #[serde(skip_serializing)]
197    pub key: Option<String>,
198    pub digest: Option<u64>,
199    #[serde(default = "Vec::new")]
200    pub from: Vec<Segment>,
201    pub from_field: Option<Vec<Segment>>,
202    pub to_field: Option<Vec<Segment>>,
203    pub start_byte: Option<usize>,
204    pub end_byte: Option<usize>,
205    pub start_point: Option<Point>,
206    pub end_point: Option<Point>,
207    pub satisfied: bool,
208}
209
210#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
211pub struct DependencySelector {
212    pub references: Vec<Vec<Segment>>,
213    pub operator: DependencySelectorOperator,
214}
215
216#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
217pub enum DependencySelectorOperator {
218    Equals,
219    NotEquals,
220    In,
221    NotIn,
222    Exists,
223    DoesNotExist,
224}
225
226#[derive(Serialize, Deserialize, Debug, Clone)]
227pub struct ApiGroupVersion {
228    pub alias: String,
229    pub group: String,
230    pub version: String,
231    pub provisioner: ProvisionerType,
232    pub status: APIGroupVersionStatus,
233}
234
235#[derive(Serialize, Deserialize, Debug, Clone)]
236pub enum APIGroupVersionStatus {
237    #[serde(rename = "UNAVAILABLE")]
238    Unavailable,
239    #[serde(rename = "PENDING")]
240    Pending,
241    #[serde(rename = "AVAILABLE")]
242    Available,
243}
244
245#[derive(Serialize, Deserialize, Debug)]
246pub struct BuildCodeIndexInput {
247    pub documents: Vec<SimpleDocument>,
248}
249
250#[derive(Serialize, Deserialize, Debug, Clone)]
251pub struct IndexError {
252    pub uri: String,
253    pub message: String,
254    pub details: Option<serde_json::Value>,
255}
256
257#[derive(Serialize, Deserialize, Debug, Clone)]
258pub struct BuildCodeIndexOutput {
259    pub blocks: Vec<Block>,
260    pub errors: Vec<IndexError>,
261    pub warnings: Vec<IndexError>,
262}
263
264#[derive(Serialize, Deserialize, Debug, Clone)]
265pub struct CodeIndex {
266    pub last_updated: DateTime<Utc>,
267    pub index: BuildCodeIndexOutput,
268}
269
270#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
271#[serde(rename_all = "UPPERCASE")]
272pub enum RuleBookVisibility {
273    #[default]
274    Public,
275    Private,
276}
277
278#[derive(Serialize, Deserialize, Debug, Clone)]
279pub struct RuleBook {
280    pub id: String,
281    pub uri: String,
282    pub description: String,
283    pub content: String,
284    pub visibility: RuleBookVisibility,
285    pub tags: Vec<String>,
286    pub created_at: Option<DateTime<Utc>>,
287    pub updated_at: Option<DateTime<Utc>>,
288}
289
290#[derive(Serialize, Deserialize, Debug)]
291pub struct ToolsCallParams {
292    pub name: String,
293    pub arguments: Value,
294}
295
296#[derive(Serialize, Deserialize, Debug)]
297pub struct ToolsCallResponse {
298    pub content: Vec<Content>,
299}
300
301#[derive(Serialize, Deserialize, Debug, Clone)]
302pub struct APIKeyScope {
303    pub r#type: String,
304    pub name: String,
305}
306
307impl std::fmt::Display for APIKeyScope {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        write!(f, "{} ({})", self.name, self.r#type)
310    }
311}
312
313#[derive(Serialize, Deserialize, Clone, Debug)]
314pub struct GetMyAccountResponse {
315    pub username: String,
316    pub id: String,
317    pub first_name: String,
318    pub last_name: String,
319    pub email: String,
320    pub scope: Option<APIKeyScope>,
321}
322
323impl GetMyAccountResponse {
324    pub fn to_text(&self) -> String {
325        format!(
326            "ID: {}\nUsername: {}\nName: {} {}\nEmail: {}",
327            self.id, self.username, self.first_name, self.last_name, self.email
328        )
329    }
330}
331
332#[derive(Serialize, Deserialize, Debug, Clone)]
333pub struct ListRuleBook {
334    pub id: String,
335    pub uri: String,
336    pub description: String,
337    pub visibility: RuleBookVisibility,
338    pub tags: Vec<String>,
339    pub created_at: Option<DateTime<Utc>>,
340    pub updated_at: Option<DateTime<Utc>>,
341}
342
343#[derive(Serialize, Deserialize, Debug)]
344pub struct ListRulebooksResponse {
345    pub results: Vec<ListRuleBook>,
346}
347
348#[derive(Serialize, Deserialize, Debug)]
349pub struct CreateRuleBookInput {
350    pub uri: String,
351    pub description: String,
352    pub content: String,
353    pub tags: Vec<String>,
354    #[serde(skip_serializing_if = "Option::is_none")]
355    pub visibility: Option<RuleBookVisibility>,
356}
357
358#[derive(Serialize, Deserialize, Debug)]
359pub struct CreateRuleBookResponse {
360    pub id: String,
361}
362
363impl ListRuleBook {
364    pub fn to_text(&self) -> String {
365        format!(
366            "URI: {}\nDescription: {}\nTags: {}\n",
367            self.uri,
368            self.description,
369            self.tags.join(", ")
370        )
371    }
372}
373
374#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
375pub struct SimpleLLMMessage {
376    #[serde(rename = "role")]
377    pub role: SimpleLLMRole,
378    pub content: String,
379}
380
381#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
382#[serde(rename_all = "lowercase")]
383pub enum SimpleLLMRole {
384    User,
385    Assistant,
386}
387
388impl std::fmt::Display for SimpleLLMRole {
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        match self {
391            SimpleLLMRole::User => write!(f, "user"),
392            SimpleLLMRole::Assistant => write!(f, "assistant"),
393        }
394    }
395}
396
397#[derive(Debug, Deserialize, Serialize)]
398pub struct SearchDocsRequest {
399    pub keywords: String,
400    pub exclude_keywords: Option<String>,
401    pub limit: Option<u32>,
402}
403
404#[derive(Debug, Deserialize, Serialize)]
405pub struct SearchMemoryRequest {
406    pub keywords: Vec<String>,
407    pub start_time: Option<DateTime<Utc>>,
408    pub end_time: Option<DateTime<Utc>>,
409}
410
411#[derive(Debug, Deserialize, Serialize)]
412pub struct SlackReadMessagesRequest {
413    pub channel: String,
414    pub limit: Option<u32>,
415}
416
417#[derive(Debug, Deserialize, Serialize)]
418pub struct SlackReadRepliesRequest {
419    pub channel: String,
420    pub ts: String,
421}
422
423#[derive(Debug, Deserialize, Serialize)]
424pub struct SlackSendMessageRequest {
425    pub channel: String,
426    pub markdown_text: String,
427    pub thread_ts: Option<String>,
428}
429
430#[derive(Debug, Clone, Default, Serialize)]
431pub struct AgentState {
432    /// The active model to use for inference
433    pub active_model: Model,
434    pub messages: Vec<ChatMessage>,
435    pub tools: Option<Vec<Tool>>,
436
437    pub llm_input: Option<LLMInput>,
438    pub llm_output: Option<LLMOutput>,
439
440    pub metadata: Option<HashMap<String, Value>>,
441}
442
443#[derive(Debug, Clone, Default, Serialize)]
444pub struct LLMOutput {
445    pub new_message: LLMMessage,
446    pub usage: LLMTokenUsage,
447}
448
449impl From<&LLMOutput> for ChatMessage {
450    fn from(value: &LLMOutput) -> Self {
451        let message_content = match &value.new_message.content {
452            LLMMessageContent::String(s) => s.clone(),
453            LLMMessageContent::List(l) => l
454                .iter()
455                .map(|c| match c {
456                    LLMMessageTypedContent::Text { text } => text.clone(),
457                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
458                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
459                    LLMMessageTypedContent::Image { .. } => String::new(),
460                })
461                .collect::<Vec<_>>()
462                .join("\n"),
463        };
464        let tool_calls = if let LLMMessageContent::List(items) = &value.new_message.content {
465            let calls: Vec<ToolCall> = items
466                .iter()
467                .filter_map(|item| {
468                    if let LLMMessageTypedContent::ToolCall { id, name, args } = item {
469                        Some(ToolCall {
470                            id: id.clone(),
471                            r#type: "function".to_string(),
472                            function: FunctionCall {
473                                name: name.clone(),
474                                arguments: args.to_string(),
475                            },
476                        })
477                    } else {
478                        None
479                    }
480                })
481                .collect();
482
483            if calls.is_empty() { None } else { Some(calls) }
484        } else {
485            None
486        };
487        ChatMessage {
488            role: Role::Assistant,
489            content: Some(MessageContent::String(message_content)),
490            name: None,
491            tool_calls,
492            tool_call_id: None,
493            usage: Some(value.usage.clone()),
494            ..Default::default()
495        }
496    }
497}
498
499impl AgentState {
500    pub fn new(active_model: Model, messages: Vec<ChatMessage>, tools: Option<Vec<Tool>>) -> Self {
501        Self {
502            active_model,
503            messages,
504            tools,
505            llm_input: None,
506            llm_output: None,
507            metadata: None,
508        }
509    }
510
511    pub fn set_messages(&mut self, messages: Vec<ChatMessage>) {
512        self.messages = messages;
513    }
514
515    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) {
516        self.tools = tools;
517    }
518
519    pub fn set_active_model(&mut self, model: Model) {
520        self.active_model = model;
521    }
522
523    pub fn set_llm_input(&mut self, llm_input: Option<LLMInput>) {
524        self.llm_input = llm_input;
525    }
526
527    pub fn set_llm_output(&mut self, new_message: LLMMessage, new_usage: Option<LLMTokenUsage>) {
528        self.llm_output = Some(LLMOutput {
529            new_message,
530            usage: new_usage.unwrap_or_default(),
531        });
532    }
533
534    pub fn append_new_message(&mut self, new_message: ChatMessage) {
535        self.messages.push(new_message);
536    }
537}