1use crate::AgentProvider;
9use crate::local::db;
10use crate::models::*;
11use crate::stakpak::{
12 CheckpointState, CreateCheckpointRequest, CreateSessionRequest, ListCheckpointsQuery,
13 ListSessionsQuery,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakpak_shared::hooks::{HookContext, LifecycleEvent};
20use stakpak_shared::models::integrations::anthropic::AnthropicModel;
21use stakpak_shared::models::integrations::openai::{
22 AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23 ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26 GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMStreamInput,
27};
28use stakpak_shared::models::stakai_adapter::get_stakai_model_string;
29use std::pin::Pin;
30use tokio::sync::mpsc;
31use uuid::Uuid;
32
33use super::AgentClient;
34
35#[derive(Debug)]
40pub(crate) enum StreamMessage {
41 Delta(GenerationDelta),
42 Ctx(Box<HookContext<AgentState>>),
43}
44
45#[async_trait]
50impl AgentProvider for AgentClient {
51 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
56 if let Some(api) = &self.stakpak_api {
57 api.get_account().await
58 } else {
59 Ok(GetMyAccountResponse {
61 username: "local".to_string(),
62 id: "local".to_string(),
63 first_name: "local".to_string(),
64 last_name: "local".to_string(),
65 email: "local@stakpak.dev".to_string(),
66 scope: None,
67 })
68 }
69 }
70
71 async fn get_billing_info(
72 &self,
73 account_username: &str,
74 ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
75 if let Some(api) = &self.stakpak_api {
76 api.get_billing(account_username).await
77 } else {
78 Err("Billing info not available without Stakpak API key".to_string())
79 }
80 }
81
82 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
87 if let Some(api) = &self.stakpak_api {
88 api.list_rulebooks().await
89 } else {
90 let client = stakpak_shared::tls_client::create_tls_client(
92 stakpak_shared::tls_client::TlsClientConfig::default()
93 .with_timeout(std::time::Duration::from_secs(30)),
94 )?;
95
96 let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
97 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
98
99 if response.status().is_success() {
100 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
101 match serde_json::from_value::<ListRulebooksResponse>(value) {
102 Ok(resp) => Ok(resp.results),
103 Err(_) => Ok(vec![]),
104 }
105 } else {
106 Ok(vec![])
107 }
108 }
109 }
110
111 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
112 if let Some(api) = &self.stakpak_api {
113 api.get_rulebook_by_uri(uri).await
114 } else {
115 let client = stakpak_shared::tls_client::create_tls_client(
117 stakpak_shared::tls_client::TlsClientConfig::default()
118 .with_timeout(std::time::Duration::from_secs(30)),
119 )?;
120
121 let encoded_uri = urlencoding::encode(uri);
122 let url = format!(
123 "{}/v1/rules/{}",
124 self.get_stakpak_api_endpoint(),
125 encoded_uri
126 );
127 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
128
129 if response.status().is_success() {
130 response.json().await.map_err(|e| e.to_string())
131 } else {
132 Err("Rulebook not found".to_string())
133 }
134 }
135 }
136
137 async fn create_rulebook(
138 &self,
139 uri: &str,
140 description: &str,
141 content: &str,
142 tags: Vec<String>,
143 visibility: Option<RuleBookVisibility>,
144 ) -> Result<CreateRuleBookResponse, String> {
145 if let Some(api) = &self.stakpak_api {
146 api.create_rulebook(&CreateRuleBookInput {
147 uri: uri.to_string(),
148 description: description.to_string(),
149 content: content.to_string(),
150 tags,
151 visibility,
152 })
153 .await
154 } else {
155 Err("Creating rulebooks requires Stakpak API key".to_string())
156 }
157 }
158
159 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
160 if let Some(api) = &self.stakpak_api {
161 api.delete_rulebook(uri).await
162 } else {
163 Err("Deleting rulebooks requires Stakpak API key".to_string())
164 }
165 }
166
167 async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
172 if let Some(api) = &self.stakpak_api {
173 let response = api.list_sessions(&ListSessionsQuery::default()).await?;
175 Ok(response
176 .sessions
177 .into_iter()
178 .map(|s| AgentSession {
179 id: s.id,
180 title: s.title,
181 agent_id: AgentID::PabloV1,
182 visibility: match s.visibility {
183 crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
184 crate::stakpak::SessionVisibility::Private => {
185 AgentSessionVisibility::Private
186 }
187 },
188 checkpoints: vec![], created_at: s.created_at,
190 updated_at: s.updated_at,
191 })
192 .collect())
193 } else {
194 db::list_sessions(&self.local_db).await
196 }
197 }
198
199 async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
200 if let Some(api) = &self.stakpak_api {
201 let response = api.get_session(session_id).await?;
202 let s = response.session;
203
204 let checkpoints_response = api
206 .list_checkpoints(session_id, &ListCheckpointsQuery::default())
207 .await?;
208
209 Ok(AgentSession {
210 id: s.id,
211 title: s.title,
212 agent_id: AgentID::PabloV1,
213 visibility: match s.visibility {
214 crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
215 crate::stakpak::SessionVisibility::Private => AgentSessionVisibility::Private,
216 },
217 checkpoints: checkpoints_response
218 .checkpoints
219 .into_iter()
220 .enumerate()
221 .map(|(i, c)| AgentCheckpointListItem {
222 id: c.id,
223 status: AgentStatus::Complete,
224 execution_depth: i,
225 parent: c.parent_id.map(|id| AgentParentCheckpoint { id }),
226 created_at: c.created_at,
227 updated_at: c.updated_at,
228 })
229 .collect(),
230 created_at: s.created_at,
231 updated_at: s.updated_at,
232 })
233 } else {
234 db::get_session(&self.local_db, session_id).await
235 }
236 }
237
238 async fn get_agent_session_stats(
239 &self,
240 _session_id: Uuid,
241 ) -> Result<AgentSessionStats, String> {
242 Ok(AgentSessionStats::default())
244 }
245
246 async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
247 if let Some(api) = &self.stakpak_api {
248 let response = api.get_checkpoint(checkpoint_id).await?;
249 let c = response.checkpoint;
250
251 let session_response = api.get_session(c.session_id).await?;
253 let s = session_response.session;
254
255 Ok(RunAgentOutput {
256 checkpoint: AgentCheckpointListItem {
257 id: c.id,
258 status: AgentStatus::Complete,
259 execution_depth: 0, parent: c.parent_id.map(|id| AgentParentCheckpoint { id }),
261 created_at: c.created_at,
262 updated_at: c.updated_at,
263 },
264 session: AgentSessionListItem {
265 id: s.id,
266 agent_id: AgentID::PabloV1,
267 visibility: match s.visibility {
268 crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
269 crate::stakpak::SessionVisibility::Private => {
270 AgentSessionVisibility::Private
271 }
272 },
273 created_at: s.created_at,
274 updated_at: s.updated_at,
275 },
276 output: AgentOutput::PabloV1 {
277 messages: c.state.messages,
278 node_states: serde_json::json!({}),
279 },
280 })
281 } else {
282 db::get_checkpoint(&self.local_db, checkpoint_id).await
283 }
284 }
285
286 async fn get_agent_session_latest_checkpoint(
287 &self,
288 session_id: Uuid,
289 ) -> Result<RunAgentOutput, String> {
290 if let Some(api) = &self.stakpak_api {
291 let session_response = api.get_session(session_id).await?;
293 let s = session_response.session;
294
295 if let Some(active_checkpoint) = s.active_checkpoint {
296 Ok(RunAgentOutput {
297 checkpoint: AgentCheckpointListItem {
298 id: active_checkpoint.id,
299 status: AgentStatus::Complete,
300 execution_depth: 0,
301 parent: active_checkpoint
302 .parent_id
303 .map(|id| AgentParentCheckpoint { id }),
304 created_at: active_checkpoint.created_at,
305 updated_at: active_checkpoint.updated_at,
306 },
307 session: AgentSessionListItem {
308 id: s.id,
309 agent_id: AgentID::PabloV1,
310 visibility: match s.visibility {
311 crate::stakpak::SessionVisibility::Public => {
312 AgentSessionVisibility::Public
313 }
314 crate::stakpak::SessionVisibility::Private => {
315 AgentSessionVisibility::Private
316 }
317 },
318 created_at: s.created_at,
319 updated_at: s.updated_at,
320 },
321 output: AgentOutput::PabloV1 {
322 messages: active_checkpoint.state.messages,
323 node_states: serde_json::json!({}),
324 },
325 })
326 } else {
327 Err("Session has no active checkpoint".to_string())
328 }
329 } else {
330 db::get_latest_checkpoint(&self.local_db, session_id).await
331 }
332 }
333
334 async fn chat_completion(
339 &self,
340 model: AgentModel,
341 messages: Vec<ChatMessage>,
342 tools: Option<Vec<Tool>>,
343 ) -> Result<ChatCompletionResponse, String> {
344 let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));
345
346 self.hook_registry
348 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
349 .await
350 .map_err(|e| e.to_string())?
351 .ok()?;
352
353 let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
355 ctx.set_session_id(current_checkpoint.session.id);
356
357 let new_message = self.run_agent_completion(&mut ctx, None).await?;
359 ctx.state.append_new_message(new_message.clone());
360
361 let result = self
363 .update_session(¤t_checkpoint, ctx.state.messages.clone())
364 .await?;
365 let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
366 ctx.set_new_checkpoint_id(result.checkpoint.id);
367
368 self.hook_registry
370 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
371 .await
372 .map_err(|e| e.to_string())?
373 .ok()?;
374
375 Ok(ChatCompletionResponse {
376 id: ctx.new_checkpoint_id.unwrap().to_string(),
377 object: "chat.completion".to_string(),
378 created: checkpoint_created_at,
379 model: ctx
380 .state
381 .llm_input
382 .as_ref()
383 .map(|llm_input| llm_input.model.clone().to_string())
384 .unwrap_or_default(),
385 choices: vec![ChatCompletionChoice {
386 index: 0,
387 message: ctx.state.messages.last().cloned().unwrap(),
388 logprobs: None,
389 finish_reason: FinishReason::Stop,
390 }],
391 usage: ctx
392 .state
393 .llm_output
394 .as_ref()
395 .map(|u| u.usage.clone())
396 .unwrap_or_default(),
397 system_fingerprint: None,
398 metadata: None,
399 })
400 }
401
402 async fn chat_completion_stream(
403 &self,
404 model: AgentModel,
405 messages: Vec<ChatMessage>,
406 tools: Option<Vec<Tool>>,
407 _headers: Option<HeaderMap>,
408 ) -> Result<
409 (
410 Pin<
411 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
412 >,
413 Option<String>,
414 ),
415 String,
416 > {
417 let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));
418
419 self.hook_registry
421 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
422 .await
423 .map_err(|e| e.to_string())?
424 .ok()?;
425
426 let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
428 ctx.set_session_id(current_checkpoint.session.id);
429
430 let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);
431
432 let _ = tx
434 .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
435 content: format!(
436 "\n<checkpoint_id>{}</checkpoint_id>\n",
437 current_checkpoint.checkpoint.id
438 ),
439 })))
440 .await;
441
442 let client = self.clone();
444 let mut ctx_clone = ctx.clone();
445
446 tokio::spawn(async move {
450 if tx.is_closed() {
452 return;
453 }
454
455 let result = client
456 .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
457 .await;
458
459 match result {
460 Err(e) => {
461 let _ = tx.send(Err(e)).await;
462 }
463 Ok(new_message) => {
464 if tx.is_closed() {
466 return;
467 }
468
469 ctx_clone.state.append_new_message(new_message.clone());
470 if tx
471 .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
472 .await
473 .is_err()
474 {
475 return;
477 }
478
479 if tx.is_closed() {
481 return;
482 }
483
484 let output = client
485 .update_session(¤t_checkpoint, ctx_clone.state.messages.clone())
486 .await;
487
488 match output {
489 Err(e) => {
490 let _ = tx.send(Err(e)).await;
491 }
492 Ok(output) => {
493 ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
494 let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
495 let _ = tx
496 .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
497 content: format!(
498 "\n<checkpoint_id>{}</checkpoint_id>\n",
499 output.checkpoint.id
500 ),
501 })))
502 .await;
503 }
504 }
505 }
506 }
507 });
508
509 let hook_registry = self.hook_registry.clone();
510 let stream = async_stream::stream! {
511 while let Some(delta_result) = rx.recv().await {
512 match delta_result {
513 Ok(delta) => match delta {
514 StreamMessage::Ctx(updated_ctx) => {
515 ctx = *updated_ctx;
516 }
517 StreamMessage::Delta(delta) => {
518 yield Ok(ChatCompletionStreamResponse {
519 id: ctx.request_id.to_string(),
520 object: "chat.completion.chunk".to_string(),
521 created: chrono::Utc::now().timestamp() as u64,
522 model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
523 choices: vec![ChatCompletionStreamChoice {
524 index: 0,
525 delta: delta.into(),
526 finish_reason: None,
527 }],
528 usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
529 metadata: None,
530 })
531 }
532 }
533 Err(e) => yield Err(ApiStreamError::Unknown(e)),
534 }
535 }
536
537 hook_registry
539 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
540 .await
541 .map_err(|e| e.to_string())?
542 .ok()?;
543 };
544
545 Ok((Box::pin(stream), None))
546 }
547
548 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
549 if let Some(api) = &self.stakpak_api {
550 api.cancel_request(&request_id).await
551 } else {
552 Ok(())
554 }
555 }
556
557 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
562 if let Some(api) = &self.stakpak_api {
563 api.search_docs(&crate::stakpak::SearchDocsRequest {
564 keywords: input.keywords.clone(),
565 exclude_keywords: input.exclude_keywords.clone(),
566 limit: input.limit,
567 })
568 .await
569 } else {
570 use stakpak_shared::models::integrations::search_service::*;
572
573 let config = SearchServicesOrchestrator::start()
574 .await
575 .map_err(|e| e.to_string())?;
576
577 let api_url = format!("http://localhost:{}", config.api_port);
578 let search_client = SearchClient::new(api_url);
579
580 let search_results = search_client
581 .search_and_scrape(input.keywords.clone(), None)
582 .await
583 .map_err(|e| e.to_string())?;
584
585 if search_results.is_empty() {
586 return Ok(vec![Content::text("No results found".to_string())]);
587 }
588
589 Ok(search_results
590 .into_iter()
591 .map(|result| {
592 let content = result.content.unwrap_or_default();
593 Content::text(format!("URL: {}\nContent: {}", result.url, content))
594 })
595 .collect())
596 }
597 }
598
599 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
604 if let Some(api) = &self.stakpak_api {
605 api.memorize_session(checkpoint_id).await
606 } else {
607 Ok(())
609 }
610 }
611
612 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
613 if let Some(api) = &self.stakpak_api {
614 api.search_memory(&crate::stakpak::SearchMemoryRequest {
615 keywords: input.keywords.clone(),
616 start_time: input.start_time,
617 end_time: input.end_time,
618 })
619 .await
620 } else {
621 Ok(vec![])
623 }
624 }
625
626 async fn slack_read_messages(
631 &self,
632 input: &SlackReadMessagesRequest,
633 ) -> Result<Vec<Content>, String> {
634 if let Some(api) = &self.stakpak_api {
635 api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
636 channel: input.channel.clone(),
637 limit: input.limit,
638 })
639 .await
640 } else {
641 Err("Slack integration requires Stakpak API key".to_string())
642 }
643 }
644
645 async fn slack_read_replies(
646 &self,
647 input: &SlackReadRepliesRequest,
648 ) -> Result<Vec<Content>, String> {
649 if let Some(api) = &self.stakpak_api {
650 api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
651 channel: input.channel.clone(),
652 ts: input.ts.clone(),
653 })
654 .await
655 } else {
656 Err("Slack integration requires Stakpak API key".to_string())
657 }
658 }
659
660 async fn slack_send_message(
661 &self,
662 input: &SlackSendMessageRequest,
663 ) -> Result<Vec<Content>, String> {
664 if let Some(api) = &self.stakpak_api {
665 api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
666 channel: input.channel.clone(),
667 mrkdwn_text: input.mrkdwn_text.clone(),
668 thread_ts: input.thread_ts.clone(),
669 })
670 .await
671 } else {
672 Err("Slack integration requires Stakpak API key".to_string())
673 }
674 }
675}
676
677const TITLE_GENERATOR_PROMPT: &str =
682 include_str!("../local/prompts/session_title_generator.v1.txt");
683
684impl AgentClient {
685 pub(crate) async fn initialize_session(
687 &self,
688 messages: &[ChatMessage],
689 ) -> Result<RunAgentOutput, String> {
690 if messages.is_empty() {
691 return Err("At least one message is required".to_string());
692 }
693
694 let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
696 message
697 .content
698 .as_ref()
699 .and_then(|content| content.extract_checkpoint_id())
700 });
701
702 if let Some(checkpoint_id) = checkpoint_id {
703 return self.get_agent_checkpoint(checkpoint_id).await;
705 }
706
707 let title = match self.generate_session_title(messages).await {
710 Ok(title) => title,
711 Err(_) => {
712 messages
714 .iter()
715 .find(|m| m.role == Role::User)
716 .and_then(|m| m.content.as_ref())
717 .map(|c| {
718 let text = c.to_string();
719 text.split_whitespace()
720 .take(5)
721 .collect::<Vec<_>>()
722 .join(" ")
723 })
724 .unwrap_or_else(|| "New Session".to_string())
725 }
726 };
727
728 let cwd = std::env::current_dir()
730 .ok()
731 .map(|p| p.to_string_lossy().to_string());
732
733 if let Some(api) = &self.stakpak_api {
734 let mut session_request = CreateSessionRequest::new(
736 title,
737 CheckpointState {
738 messages: messages.to_vec(),
739 },
740 );
741 if let Some(cwd) = cwd {
742 session_request = session_request.with_cwd(cwd);
743 }
744 let response = api.create_session(&session_request).await?;
745
746 Ok(RunAgentOutput {
747 checkpoint: AgentCheckpointListItem {
748 id: response.checkpoint.id,
749 status: AgentStatus::Complete,
750 execution_depth: 0,
751 parent: response
752 .checkpoint
753 .parent_id
754 .map(|id| AgentParentCheckpoint { id }),
755 created_at: response.checkpoint.created_at,
756 updated_at: response.checkpoint.updated_at,
757 },
758 session: AgentSessionListItem {
759 id: response.session_id,
760 agent_id: AgentID::PabloV1,
761 visibility: AgentSessionVisibility::Private,
762 created_at: response.checkpoint.created_at,
763 updated_at: response.checkpoint.updated_at,
764 },
765 output: AgentOutput::PabloV1 {
766 messages: messages.to_vec(),
767 node_states: serde_json::json!({}),
768 },
769 })
770 } else {
771 let now = chrono::Utc::now();
773 let session_id = Uuid::new_v4();
774 let session = AgentSession {
775 id: session_id,
776 title,
777 agent_id: AgentID::PabloV1,
778 visibility: AgentSessionVisibility::Private,
779 created_at: now,
780 updated_at: now,
781 checkpoints: vec![],
782 };
783 db::create_session(&self.local_db, &session).await?;
784
785 let checkpoint_id = Uuid::new_v4();
786 let checkpoint = AgentCheckpointListItem {
787 id: checkpoint_id,
788 status: AgentStatus::Complete,
789 execution_depth: 0,
790 parent: None,
791 created_at: now,
792 updated_at: now,
793 };
794 let initial_state = AgentOutput::PabloV1 {
795 messages: messages.to_vec(),
796 node_states: serde_json::json!({}),
797 };
798 db::create_checkpoint(&self.local_db, session_id, &checkpoint, &initial_state).await?;
799
800 db::get_checkpoint(&self.local_db, checkpoint_id).await
801 }
802 }
803
804 pub(crate) async fn update_session(
806 &self,
807 checkpoint_info: &RunAgentOutput,
808 new_messages: Vec<ChatMessage>,
809 ) -> Result<RunAgentOutput, String> {
810 if let Some(api) = &self.stakpak_api {
811 let checkpoint_request = CreateCheckpointRequest::new(CheckpointState {
813 messages: new_messages.clone(),
814 })
815 .with_parent(checkpoint_info.checkpoint.id);
816
817 let response = api
818 .create_checkpoint(checkpoint_info.session.id, &checkpoint_request)
819 .await?;
820
821 Ok(RunAgentOutput {
822 checkpoint: AgentCheckpointListItem {
823 id: response.checkpoint.id,
824 status: AgentStatus::Complete,
825 execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
826 parent: Some(AgentParentCheckpoint {
827 id: checkpoint_info.checkpoint.id,
828 }),
829 created_at: response.checkpoint.created_at,
830 updated_at: response.checkpoint.updated_at,
831 },
832 session: checkpoint_info.session.clone(),
833 output: AgentOutput::PabloV1 {
834 messages: new_messages,
835 node_states: serde_json::json!({}),
836 },
837 })
838 } else {
839 let now = chrono::Utc::now();
841 let complete_checkpoint = AgentCheckpointListItem {
842 id: Uuid::new_v4(),
843 status: AgentStatus::Complete,
844 execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
845 parent: Some(AgentParentCheckpoint {
846 id: checkpoint_info.checkpoint.id,
847 }),
848 created_at: now,
849 updated_at: now,
850 };
851
852 let new_state = AgentOutput::PabloV1 {
853 messages: new_messages.clone(),
854 node_states: serde_json::json!({}),
855 };
856
857 db::create_checkpoint(
858 &self.local_db,
859 checkpoint_info.session.id,
860 &complete_checkpoint,
861 &new_state,
862 )
863 .await?;
864
865 Ok(RunAgentOutput {
866 checkpoint: complete_checkpoint,
867 session: checkpoint_info.session.clone(),
868 output: new_state,
869 })
870 }
871 }
872
873 pub(crate) async fn run_agent_completion(
875 &self,
876 ctx: &mut HookContext<AgentState>,
877 stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
878 ) -> Result<ChatMessage, String> {
879 self.hook_registry
881 .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
882 .await
883 .map_err(|e| e.to_string())?
884 .ok()?;
885
886 let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
887 llm_input
888 } else {
889 return Err(
890 "LLM input not found, make sure to register a context hook before inference"
891 .to_string(),
892 );
893 };
894
895 if let Some(session_id) = ctx.session_id {
897 let headers = input
898 .headers
899 .get_or_insert_with(std::collections::HashMap::new);
900 headers.insert("X-Session-Id".to_string(), session_id.to_string());
901 }
902
903 let (response_message, usage) = if let Some(tx) = stream_channel_tx {
904 let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
906 let stream_input = LLMStreamInput {
907 model: input.model,
908 messages: input.messages,
909 max_tokens: input.max_tokens,
910 tools: input.tools,
911 stream_channel_tx: internal_tx,
912 provider_options: input.provider_options,
913 headers: input.headers,
914 };
915
916 let stakai = self.stakai.clone();
917 let chat_future = async move {
918 stakai
919 .chat_stream(stream_input)
920 .await
921 .map_err(|e| e.to_string())
922 };
923
924 let receive_future = async move {
925 while let Some(delta) = internal_rx.recv().await {
926 if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
927 break;
928 }
929 }
930 };
931
932 let (chat_result, _) = tokio::join!(chat_future, receive_future);
933 let response = chat_result?;
934 (response.choices[0].message.clone(), response.usage)
935 } else {
936 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
938 (response.choices[0].message.clone(), response.usage)
939 };
940
941 ctx.state.set_llm_output(response_message, usage);
942
943 self.hook_registry
945 .execute_hooks(ctx, &LifecycleEvent::AfterInference)
946 .await
947 .map_err(|e| e.to_string())?
948 .ok()?;
949
950 let llm_output = ctx
951 .state
952 .llm_output
953 .as_ref()
954 .ok_or_else(|| "LLM output is missing from state".to_string())?;
955
956 Ok(ChatMessage::from(llm_output))
957 }
958
959 async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
961 let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
962 eco_model.clone()
963 } else {
964 LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
966 };
967
968 let model = if self.has_stakpak() {
970 let model_str = get_stakai_model_string(&llm_model);
972 let display_name = model_str
974 .rsplit('/')
975 .next()
976 .unwrap_or(&model_str)
977 .to_string();
978 LLMModel::Custom {
979 provider: "stakpak".to_string(),
980 model: model_str,
981 name: Some(display_name),
982 }
983 } else {
984 llm_model
985 };
986
987 let llm_messages = vec![
988 LLMMessage {
989 role: Role::System.to_string(),
990 content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
991 },
992 LLMMessage {
993 role: Role::User.to_string(),
994 content: LLMMessageContent::String(
995 messages
996 .iter()
997 .map(|msg| {
998 msg.content
999 .as_ref()
1000 .unwrap_or(&MessageContent::String("".to_string()))
1001 .to_string()
1002 })
1003 .collect(),
1004 ),
1005 },
1006 ];
1007
1008 let input = LLMInput {
1009 model,
1010 messages: llm_messages,
1011 max_tokens: 100,
1012 tools: None,
1013 provider_options: None,
1014 headers: None,
1015 };
1016
1017 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
1018
1019 Ok(response.choices[0].message.content.to_string())
1020 }
1021}