use crate::local::hooks::inline_scratchpad_context::{
    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
};
use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
use crate::{ListRuleBook, models::*};
use async_trait::async_trait;
use futures_util::Stream;
use libsql::{Builder, Connection};
use reqwest::Error as ReqwestError;
use reqwest::header::HeaderMap;
use rmcp::model::Content;
use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
use stakpak_shared::models::integrations::gemini::{GeminiConfig, GeminiModel};
use stakpak_shared::models::integrations::openai::{
    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
    OpenAIModel, Role, Tool,
};
use stakpak_shared::models::integrations::search_service::*;
use stakpak_shared::models::llm::{
    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
    LLMStreamInput,
};
use stakpak_shared::models::stakai_adapter::StakAIClient;
use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use uuid::Uuid;

mod context_managers;
mod db;
mod hooks;

#[cfg(test)]
mod tests;

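/// A fully local [`AgentProvider`] backed by an on-disk libsql database.
///
/// Sessions and checkpoints are persisted locally; inference is delegated to
/// whichever provider configs (Anthropic, OpenAI, Gemini) are supplied.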
#[derive(Clone, Debug)]
pub struct LocalClient {
    pub db: Connection,
    pub stakpak_base_url: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub model_options: ModelOptions,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

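/// User-supplied model overrides; any tier left as `None` falls back to the
/// defaults chosen in `From<ModelOptions> for ModelSet`.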
#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}

#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

impl ModelSet {
    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
        match agent_model {
            AgentModel::Smart => self.smart_model.clone(),
            AgentModel::Eco => self.eco_model.clone(),
            AgentModel::Recovery => self.recovery_model.clone(),
        }
    }
}

impl From<ModelOptions> for ModelSet {
    fn from(value: ModelOptions) -> Self {
        let smart_model = value
            .smart_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
        let eco_model = value
            .eco_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
        let recovery_model = value
            .recovery_model
            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));

        Self {
            smart_model,
            eco_model,
            recovery_model,
            hook_registry: None,
            _search_services_orchestrator: None,
        }
    }
}

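/// Configuration for [`LocalClient::new`]. The model fields take model name
/// strings (converted via `LLMModel::from`), and `store_path` overrides the
/// default on-disk database location.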
pub struct LocalClientConfig {
    pub stakpak_base_url: Option<String>,
    pub store_path: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    pub hook_registry: Option<HookRegistry<AgentState>>,
}

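/// Internal message type for the streaming pipeline: `Delta` carries
/// generation chunks destined for the client, while `Ctx` ships the updated
/// hook context (session id, checkpoint id, usage) back to the response
/// stream.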
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}

const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");

impl LocalClient {
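    /// Creates a local client: ensures the database directory exists, opens
    /// the libsql store (defaulting to `~/.stakpak/data/local.db`), initializes
    /// the schema, and registers the inline scratchpad context hook on
    /// `BeforeInference`.
    ///
    /// A minimal usage sketch (an all-`None` config, shown only for
    /// illustration; a real setup would supply at least one provider config):
    ///
    /// ```ignore
    /// let client = LocalClient::new(LocalClientConfig {
    ///     stakpak_base_url: None,
    ///     store_path: None,
    ///     anthropic_config: None,
    ///     openai_config: None,
    ///     gemini_config: None,
    ///     smart_model: None,
    ///     eco_model: None,
    ///     recovery_model: None,
    ///     hook_registry: None,
    /// })
    /// .await?;
    /// ```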
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        db::init_schema(&conn).await?;

        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        Ok(Self {
            db: conn,
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            anthropic_config: config.anthropic_config,
            gemini_config: config.gemini_config,
            openai_config: config.openai_config,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}

#[async_trait]
impl AgentProvider for LocalClient {
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        Ok(GetMyAccountResponse {
            username: "local".to_string(),
            id: "local".to_string(),
            first_name: "local".to_string(),
            last_name: "local".to_string(),
        })
    }

    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        if self.stakpak_base_url.is_none() {
            return Ok(vec![]);
        }

        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let url = format!("{}/rules", stakpak_base_url);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
            Ok(response) => Ok(response.results),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err(format!("Failed to deserialize response: {}", e))
            }
        }
    }

    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let encoded_uri = urlencoding::encode(uri);

        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<RuleBook>(value.clone()) {
            Ok(response) => Ok(response),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err(format!("Failed to deserialize response: {}", e))
            }
        }
    }

    async fn create_rulebook(
        &self,
        _uri: &str,
        _description: &str,
        _content: &str,
        _tags: Vec<String>,
        _visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }

    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }

    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        Ok(AgentSessionStats::default())
    }

    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }

    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }

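    /// Non-streaming completion: runs the `BeforeRequest` hooks, loads or
    /// creates the session, performs one inference pass, persists a new
    /// checkpoint, and runs the `AfterRequest` hooks before shaping the
    /// OpenAI-style response.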
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }

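    /// Streaming completion: the checkpoint id is surfaced to the client as a
    /// `<checkpoint_id>` marker in the content stream, inference runs in a
    /// spawned task that forwards `StreamMessage`s over an mpsc channel, and
    /// the returned stream converts each delta into an OpenAI-style chunk.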
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                        }
                        StreamMessage::Delta(delta) => {
                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }

    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }

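    /// Iterative documentation search: an LLM first analyzes and reformulates
    /// the query, then up to `MAX_ITERATIONS` rounds of search-and-scrape are
    /// validated by the model, which either accepts the results or proposes a
    /// refined query for the next round.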
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        const MAX_ITERATIONS: usize = 3;

        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                Content::text(format!(
                    "Title: {}\nURL: {}\nContent: {}",
                    result.title.unwrap_or_default(),
                    result.url,
                    result.content.unwrap_or_default(),
                ))
            })
            .collect();

        Ok(contents)
    }

    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        Ok(())
    }
}

impl LocalClient {
    fn get_llm_config(&self) -> LLMProviderConfig {
        LLMProviderConfig {
            anthropic_config: self.anthropic_config.clone(),
            openai_config: self.openai_config.clone(),
            gemini_config: self.gemini_config.clone(),
        }
    }

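    /// Single inference pass shared by both entry points. A `BeforeInference`
    /// hook must populate `ctx.state.llm_input` (the inline scratchpad
    /// context hook does this); when a stream channel is provided, deltas are
    /// forwarded to it while the full response is awaited.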
    async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        let llm_config = self.get_llm_config();
        let stakai_client = StakAIClient::new(&llm_config)
            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
                provider_options: input.provider_options,
            };

            let chat_future = async move {
                stakai_client
                    .chat_stream(input)
                    .await
                    .map_err(|e| e.to_string())
            };

            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

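    /// Resume or create a session: if the latest server message embeds a
    /// `<checkpoint_id>` marker, the matching checkpoint is loaded; otherwise
    /// a new session (with an LLM-generated title) and an initial checkpoint
    /// are persisted.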
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }

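    /// Persist the post-completion state as a new checkpoint, one execution
    /// depth below its parent, and return the refreshed `RunAgentOutput`.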
    async fn update_session(
        &self,
        checkpoint_info: &RunAgentOutput,
        new_messages: Vec<ChatMessage>,
    ) -> Result<RunAgentOutput, String> {
        let now = chrono::Utc::now();
        let complete_checkpoint = AgentCheckpointListItem {
            id: Uuid::new_v4(),
            status: AgentStatus::Complete,
            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
            parent: Some(AgentParentCheckpoint {
                id: checkpoint_info.checkpoint.id,
            }),
            created_at: now,
            updated_at: now,
        };

        let mut new_state = checkpoint_info.output.clone();
        new_state.set_messages(new_messages);

        db::create_checkpoint(
            &self.db,
            checkpoint_info.session.id,
            &complete_checkpoint,
            &new_state,
        )
        .await?;

        Ok(RunAgentOutput {
            checkpoint: complete_checkpoint,
            session: checkpoint_info.session.clone(),
            output: new_state,
        })
    }

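    /// Generate a short session title with the cheapest configured model: the
    /// eco model if set, otherwise the first available provider's lightweight
    /// default.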
    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_config = self.get_llm_config();

        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
            eco_model.clone()
        } else if llm_config.openai_config.is_some() {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if llm_config.anthropic_config.is_some() {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if llm_config.gemini_config.is_some() {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else {
            return Err("No LLM config found".to_string());
        };

        let messages = vec![
            LLMMessage {
                role: "system".to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
            },
            LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::String(
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        .collect(),
                ),
            },
        ];

        let input = LLMInput {
            model: llm_model,
            messages,
            max_tokens: 100,
            tools: None,
            provider_options: None,
        };

        let stakai_client = StakAIClient::new(&llm_config)
            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
        let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

        Ok(response.choices[0].message.content.to_string())
    }
}

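/// Ask the model to classify what documentation a query needs and to
/// reformulate it for search, returning the XML-parsed [`AnalysisResult`].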
async fn analyze_search_query(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    query: &str,
) -> Result<AnalysisResult, String> {
    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.

## Your Task

Analyze the user's search query to:
1. Identify the specific types of documentation needed
2. Reformulate the query for optimal search engine results

## Guidelines for Required Documentation

Identify specific documentation types such as:
- API references and specifications
- Installation/setup guides
- Configuration documentation
- Tutorials and getting started guides
- Troubleshooting guides
- Architecture/design documents
- CLI/command references
- SDK/library documentation

## Guidelines for Query Reformulation

Create an optimized search query that:
- Uses specific technical terminology
- Includes relevant keywords (e.g., "documentation", "guide", "API")
- Removes ambiguous or filler words
- Targets authoritative sources when possible
- Is concise but comprehensive (5-10 words ideal)

## Response Format

Respond ONLY with valid XML in this exact structure:

<analysis>
  <required_documentation>
    <item>specific documentation type needed</item>
  </required_documentation>
  <reformulated_query>optimized search query string</reformulated_query>
</analysis>"#;

    let user_prompt = format!(
        r#"<user_query>{}</user_query>

Analyze this query and provide the required documentation types and an optimized search query."#,
        query
    );

    let input = LLMInput {
        model: model.clone(),
        messages: vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(system_prompt.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(user_prompt.to_string()),
            },
        ],
        max_tokens: 2000,
        tools: None,
        provider_options: None,
    };

    let stakai_client = StakAIClient::new(llm_config)
        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    parse_analysis_xml(&content)
}

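/// Minimal tag-scanning parser for the `<analysis>` response; it does not use
/// a real XML parser, it just slices between literal `<tag>`/`</tag>` pairs.
///
/// A sketch of the expected behavior (hypothetical input, not a fixture):
///
/// ```ignore
/// let xml = "<analysis><required_documentation><item>API reference</item></required_documentation><reformulated_query>libsql rust api docs</reformulated_query></analysis>";
/// let analysis = parse_analysis_xml(xml).unwrap();
/// assert_eq!(analysis.required_documentation, vec!["API reference".to_string()]);
/// assert_eq!(analysis.reformulated_query, "libsql rust api docs");
/// ```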
fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let required_documentation = extract_all_tags("item");
    let reformulated_query =
        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;

    Ok(AnalysisResult {
        required_documentation,
        reformulated_query,
    })
}

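/// Ask the model whether the scraped results satisfy the documentation
/// requirements. Only the first 10 results (title + URL) are shown to the
/// model; the verdict comes back as `<validation>` XML and is resolved
/// against `docs` by exact URL match.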
async fn validate_search_docs(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    docs: &[ScrapedContent],
    query: &str,
    required_documentation: &[String],
    previous_queries: &[String],
    accumulated_needed_urls: &[String],
) -> Result<ValidationResult, String> {
    let docs_preview = docs
        .iter()
        .enumerate()
        .take(10)
        .map(|(i, r)| {
            format!(
                "<doc index=\"{}\">\n  <title>{}</title>\n  <url>{}</url>\n</doc>",
                i + 1,
                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
                r.url
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    let required_docs_formatted = required_documentation
        .iter()
        .map(|d| format!("    <item>{}</item>", d))
        .collect::<Vec<_>>()
        .join("\n");

    let previous_queries_formatted = previous_queries
        .iter()
        .map(|q| format!("    <query>{}</query>", q))
        .collect::<Vec<_>>()
        .join("\n");

    let accumulated_urls_formatted = accumulated_needed_urls
        .iter()
        .map(|u| format!("    <url>{}</url>", u))
        .collect::<Vec<_>>()
        .join("\n");

    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.

## Evaluation Criteria

For each search result, assess:
1. **Relevance**: Does the document directly address the required documentation topics?
2. **Authority**: Is this an official source, documentation site, or authoritative reference?
3. **Completeness**: Does it provide comprehensive information, not just passing mentions?
4. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.

## Decision Guidelines

Mark results as SATISFIED when:
- All required documentation topics have at least one authoritative source
- The sources provide actionable, detailed information
- No critical gaps remain in coverage

Suggest a NEW QUERY when:
- Key topics are missing from results
- Results are too general or tangential
- A more specific query would yield better results
- Previous queries haven't addressed certain requirements

## Response Format

Respond ONLY with valid XML in this exact structure:

<validation>
  <is_satisfied>true or false</is_satisfied>
  <valid_docs>
    <doc><url>exact URL from results</url></doc>
  </valid_docs>
  <needed_urls>
    <url>specific URL pattern or domain still needed</url>
  </needed_urls>
  <new_query>refined search query if not satisfied, omit if satisfied</new_query>
  <reasoning>brief explanation of your assessment</reasoning>
</validation>"#;

    let user_prompt = format!(
        r#"<search_context>
  <original_query>{}</original_query>
  <required_documentation>
{}
  </required_documentation>
  <previous_queries>
{}
  </previous_queries>
  <accumulated_needed_urls>
{}
  </accumulated_needed_urls>
</search_context>

<current_results>
{}
</current_results>

Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
        query,
        if required_docs_formatted.is_empty() {
            "    <item>None specified</item>".to_string()
        } else {
            required_docs_formatted
        },
        if previous_queries_formatted.is_empty() {
            "    <query>None</query>".to_string()
        } else {
            previous_queries_formatted
        },
        if accumulated_urls_formatted.is_empty() {
            "    <url>None</url>".to_string()
        } else {
            accumulated_urls_formatted
        },
        docs_preview
    );

    let input = LLMInput {
        model: model.clone(),
        messages: vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(system_prompt.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(user_prompt.to_string()),
            },
        ],
        max_tokens: 4000,
        tools: None,
        provider_options: None,
    };

    let stakai_client = StakAIClient::new(llm_config)
        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    let validation = parse_validation_xml(&content, docs)?;

    Ok(validation)
}

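/// Parse the `<validation>` XML with the same tag-scanning approach as
/// [`parse_analysis_xml`]. Note that `<url>` tags from both `<valid_docs>`
/// and `<needed_urls>` are collected together and then split by membership:
/// URLs matching a scraped doc become valid docs, the rest become
/// still-needed URLs.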
fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let is_satisfied = extract_tag("is_satisfied")
        .map(|s| s.to_lowercase() == "true")
        .unwrap_or(false);

    let valid_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| docs.iter().any(|d| d.url == *url))
        .collect();

    let valid_docs: Vec<ScrapedContent> = valid_urls
        .iter()
        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
        .collect();

    let needed_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| !docs.iter().any(|d| d.url == *url))
        .collect();

    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");

    Ok(ValidationResult {
        is_satisfied,
        valid_docs,
        needed_urls,
        new_query,
    })
}

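/// Pick a lightweight model for search analysis/validation: stay within the
/// provider family of the configured eco (or smart) model, otherwise fall
/// back to whichever provider config is present.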
fn get_search_model(
    llm_config: &LLMProviderConfig,
    eco_model: Option<LLMModel>,
    smart_model: Option<LLMModel>,
) -> LLMModel {
    let base_model = eco_model.or(smart_model);

    match base_model {
        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
        Some(LLMModel::Custom(model)) => LLMModel::Custom(model),
        None => {
            if llm_config.openai_config.is_some() {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            } else if llm_config.anthropic_config.is_some() {
                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
            } else if llm_config.gemini_config.is_some() {
                LLMModel::Gemini(GeminiModel::Gemini3Flash)
            } else {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            }
        }
    }
}