
stakpak_api/client/provider.rs

//! AgentProvider trait implementation for AgentClient
//!
//! Implements the unified provider interface with:
//! - Stakpak-first routing when API key is present
//! - Local fallback when Stakpak is unavailable
//! - Hook registry integration for lifecycle events
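//!
//! # Example
//!
//! A minimal usage sketch against the [`AgentProvider`] interface (illustrative only;
//! constructing the concrete `AgentClient` is outside the scope of this module, so the
//! example takes the provider as a parameter):
//!
//! ```ignore
//! use stakpak_shared::models::integrations::openai::{AgentModel, ChatMessage};
//!
//! async fn run_turn(
//!     provider: &impl crate::AgentProvider,
//!     model: AgentModel,
//!     messages: Vec<ChatMessage>,
//! ) -> Result<(), String> {
//!     // Routed through the Stakpak API when an API key is configured,
//!     // otherwise handled locally.
//!     let response = provider.chat_completion(model, messages, None).await?;
//!     println!("choices returned: {}", response.choices.len());
//!
//!     // Sessions and checkpoints are persisted remotely or in the local DB.
//!     let sessions = provider.list_agent_sessions().await?;
//!     println!("sessions: {}", sessions.len());
//!     Ok(())
//! }
//! ```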

use crate::AgentProvider;
use crate::local::db;
use crate::models::*;
use crate::stakpak::{
    CheckpointState, CreateCheckpointRequest, CreateSessionRequest, ListCheckpointsQuery,
    ListSessionsQuery,
};
use async_trait::async_trait;
use futures_util::Stream;
use reqwest::header::HeaderMap;
use rmcp::model::Content;
use stakpak_shared::hooks::{HookContext, LifecycleEvent};
use stakpak_shared::models::integrations::anthropic::AnthropicModel;
use stakpak_shared::models::integrations::openai::{
    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
};
use stakpak_shared::models::llm::{
    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMStreamInput,
};
use stakpak_shared::models::stakai_adapter::get_stakai_model_string;
use std::pin::Pin;
use tokio::sync::mpsc;
use uuid::Uuid;

use super::AgentClient;

// =============================================================================
// Internal Message Types
// =============================================================================

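/// Message passed over the internal streaming channel: either an incremental
/// generation delta to forward to the stream consumer, or an updated hook
/// context carrying session/checkpoint state back into the response stream.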
#[derive(Debug)]
pub(crate) enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}

// =============================================================================
// AgentProvider Implementation
// =============================================================================

#[async_trait]
impl AgentProvider for AgentClient {
    // =========================================================================
    // Account
    // =========================================================================

    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_account().await
        } else {
            // Local stub
            Ok(GetMyAccountResponse {
                username: "local".to_string(),
                id: "local".to_string(),
                first_name: "local".to_string(),
                last_name: "local".to_string(),
                email: "local@stakpak.dev".to_string(),
                scope: None,
            })
        }
    }

    async fn get_billing_info(
        &self,
        account_username: &str,
    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_billing(account_username).await
        } else {
            Err("Billing info not available without Stakpak API key".to_string())
        }
    }

    // =========================================================================
    // Rulebooks
    // =========================================================================

    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        if let Some(api) = &self.stakpak_api {
            api.list_rulebooks().await
        } else {
            // Try to fetch public rulebooks via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
                match serde_json::from_value::<ListRulebooksResponse>(value) {
                    Ok(resp) => Ok(resp.results),
                    Err(_) => Ok(vec![]),
                }
            } else {
                Ok(vec![])
            }
        }
    }

    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_rulebook_by_uri(uri).await
        } else {
            // Try to fetch public rulebook via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            let encoded_uri = urlencoding::encode(uri);
            let url = format!(
                "{}/v1/rules/{}",
                self.get_stakpak_api_endpoint(),
                encoded_uri
            );
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                response.json().await.map_err(|e| e.to_string())
            } else {
                Err("Rulebook not found".to_string())
            }
        }
    }

    async fn create_rulebook(
        &self,
        uri: &str,
        description: &str,
        content: &str,
        tags: Vec<String>,
        visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.create_rulebook(&CreateRuleBookInput {
                uri: uri.to_string(),
                description: description.to_string(),
                content: content.to_string(),
                tags,
                visibility,
            })
            .await
        } else {
            Err("Creating rulebooks requires Stakpak API key".to_string())
        }
    }

    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.delete_rulebook(uri).await
        } else {
            Err("Deleting rulebooks requires Stakpak API key".to_string())
        }
    }

    // =========================================================================
    // Agent Sessions
    // =========================================================================

    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        if let Some(api) = &self.stakpak_api {
            // Use Stakpak API
            let response = api.list_sessions(&ListSessionsQuery::default()).await?;
            Ok(response
                .sessions
                .into_iter()
                .map(|s| AgentSession {
                    id: s.id,
                    title: s.title,
                    agent_id: AgentID::PabloV1,
                    visibility: match s.visibility {
                        crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
                        crate::stakpak::SessionVisibility::Private => {
                            AgentSessionVisibility::Private
                        }
                    },
                    checkpoints: vec![], // Summary doesn't include full checkpoints
                    created_at: s.created_at,
                    updated_at: s.updated_at,
                })
                .collect())
        } else {
            // Fallback to local DB
            db::list_sessions(&self.local_db).await
        }
    }

    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        if let Some(api) = &self.stakpak_api {
            let response = api.get_session(session_id).await?;
            let s = response.session;

            // Get checkpoints for this session
            let checkpoints_response = api
                .list_checkpoints(session_id, &ListCheckpointsQuery::default())
                .await?;

            Ok(AgentSession {
                id: s.id,
                title: s.title,
                agent_id: AgentID::PabloV1,
                visibility: match s.visibility {
                    crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
                    crate::stakpak::SessionVisibility::Private => AgentSessionVisibility::Private,
                },
                checkpoints: checkpoints_response
                    .checkpoints
                    .into_iter()
                    .enumerate()
                    .map(|(i, c)| AgentCheckpointListItem {
                        id: c.id,
                        status: AgentStatus::Complete,
                        execution_depth: i,
                        parent: c.parent_id.map(|id| AgentParentCheckpoint { id }),
                        created_at: c.created_at,
                        updated_at: c.updated_at,
                    })
                    .collect(),
                created_at: s.created_at,
                updated_at: s.updated_at,
            })
        } else {
            db::get_session(&self.local_db, session_id).await
        }
    }

    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        // TODO: Implement session stats via Stakpak API when available
        Ok(AgentSessionStats::default())
    }

    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        if let Some(api) = &self.stakpak_api {
            let response = api.get_checkpoint(checkpoint_id).await?;
            let c = response.checkpoint;

            // Get session info
            let session_response = api.get_session(c.session_id).await?;
            let s = session_response.session;

            Ok(RunAgentOutput {
                checkpoint: AgentCheckpointListItem {
                    id: c.id,
                    status: AgentStatus::Complete,
                    execution_depth: 0, // We don't track depth in new API
                    parent: c.parent_id.map(|id| AgentParentCheckpoint { id }),
                    created_at: c.created_at,
                    updated_at: c.updated_at,
                },
                session: AgentSessionListItem {
                    id: s.id,
                    agent_id: AgentID::PabloV1,
                    visibility: match s.visibility {
                        crate::stakpak::SessionVisibility::Public => AgentSessionVisibility::Public,
                        crate::stakpak::SessionVisibility::Private => {
                            AgentSessionVisibility::Private
                        }
                    },
                    created_at: s.created_at,
                    updated_at: s.updated_at,
                },
                output: AgentOutput::PabloV1 {
                    messages: c.state.messages,
                    node_states: serde_json::json!({}),
                },
            })
        } else {
            db::get_checkpoint(&self.local_db, checkpoint_id).await
        }
    }

    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        if let Some(api) = &self.stakpak_api {
            // Get session with active checkpoint
            let session_response = api.get_session(session_id).await?;
            let s = session_response.session;

            if let Some(active_checkpoint) = s.active_checkpoint {
                Ok(RunAgentOutput {
                    checkpoint: AgentCheckpointListItem {
                        id: active_checkpoint.id,
                        status: AgentStatus::Complete,
                        execution_depth: 0,
                        parent: active_checkpoint
                            .parent_id
                            .map(|id| AgentParentCheckpoint { id }),
                        created_at: active_checkpoint.created_at,
                        updated_at: active_checkpoint.updated_at,
                    },
                    session: AgentSessionListItem {
                        id: s.id,
                        agent_id: AgentID::PabloV1,
                        visibility: match s.visibility {
                            crate::stakpak::SessionVisibility::Public => {
                                AgentSessionVisibility::Public
                            }
                            crate::stakpak::SessionVisibility::Private => {
                                AgentSessionVisibility::Private
                            }
                        },
                        created_at: s.created_at,
                        updated_at: s.updated_at,
                    },
                    output: AgentOutput::PabloV1 {
                        messages: active_checkpoint.state.messages,
                        node_states: serde_json::json!({}),
                    },
                })
            } else {
                Err("Session has no active checkpoint".to_string())
            }
        } else {
            db::get_latest_checkpoint(&self.local_db, session_id).await
        }
    }

    // =========================================================================
    // Chat Completion
    // =========================================================================

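    /// Non-streaming chat completion.
    ///
    /// Executes BeforeRequest hooks, initializes or resumes the session, runs a single
    /// inference turn, persists the result as a new checkpoint, executes AfterRequest
    /// hooks, and returns a response whose `id` is the new checkpoint ID.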
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        // Execute before request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize or resume session
        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        // Run completion
        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Save checkpoint
        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        // Execute after request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        Ok(ChatCompletionResponse {
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }

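    /// Streaming chat completion.
    ///
    /// Follows the same lifecycle as `chat_completion`, but forwards deltas as
    /// `chat.completion.chunk` responses. Checkpoint boundaries are emitted inline in the
    /// content stream as `<checkpoint_id>...</checkpoint_id>` markers, and the spawned
    /// completion task exits early if the stream consumer is dropped.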
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        // Execute before request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize session
        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Send initial checkpoint ID
        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        // Clone what we need for the spawned task
        let client = self.clone();
        let mut ctx_clone = ctx.clone();

        // Spawn the completion task with proper shutdown handling
        // The task checks if the channel is closed before each expensive operation
        // to support graceful shutdown when the stream consumer is dropped
        tokio::spawn(async move {
            // Check if consumer is still listening before starting
            if tx.is_closed() {
                return;
            }

            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    // Check if consumer is still listening before continuing
                    if tx.is_closed() {
                        return;
                    }

                    ctx_clone.state.append_new_message(new_message.clone());
                    if tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await
                        .is_err()
                    {
                        // Consumer dropped, exit gracefully
                        return;
                    }

                    // Check again before expensive session update
                    if tx.is_closed() {
                        return;
                    }

                    let output = client
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                        }
                        StreamMessage::Delta(delta) => {
                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Execute after request hooks
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        };

        Ok((Box::pin(stream), None))
    }

    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.cancel_request(&request_id).await
        } else {
            // Local mode doesn't support cancellation yet
            Ok(())
        }
    }

    // =========================================================================
    // Search Docs
    // =========================================================================

    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.search_docs(&crate::stakpak::SearchDocsRequest {
                keywords: input.keywords.clone(),
                exclude_keywords: input.exclude_keywords.clone(),
                limit: input.limit,
            })
            .await
        } else {
            // Fallback to local search service
            use stakpak_shared::models::integrations::search_service::*;

            let config = SearchServicesOrchestrator::start()
                .await
                .map_err(|e| e.to_string())?;

            let api_url = format!("http://localhost:{}", config.api_port);
            let search_client = SearchClient::new(api_url);

            let search_results = search_client
                .search_and_scrape(input.keywords.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                return Ok(vec![Content::text("No results found".to_string())]);
            }

            Ok(search_results
                .into_iter()
                .map(|result| {
                    let content = result.content.unwrap_or_default();
                    Content::text(format!("URL: {}\nContent: {}", result.url, content))
                })
                .collect())
        }
    }

    // =========================================================================
    // Memory
    // =========================================================================

    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.memorize_session(checkpoint_id).await
        } else {
            // No-op in local mode
            Ok(())
        }
    }

    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.search_memory(&crate::stakpak::SearchMemoryRequest {
                keywords: input.keywords.clone(),
                start_time: input.start_time,
                end_time: input.end_time,
            })
            .await
        } else {
            // Empty results in local mode
            Ok(vec![])
        }
    }

    // =========================================================================
    // Slack
    // =========================================================================

    async fn slack_read_messages(
        &self,
        input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
                channel: input.channel.clone(),
                limit: input.limit,
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }

    async fn slack_read_replies(
        &self,
        input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
                channel: input.channel.clone(),
                ts: input.ts.clone(),
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }

    async fn slack_send_message(
        &self,
        input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
                channel: input.channel.clone(),
                mrkdwn_text: input.mrkdwn_text.clone(),
                thread_ts: input.thread_ts.clone(),
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }
}

// =============================================================================
// Helper Methods
// =============================================================================

const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");

impl AgentClient {
    /// Initialize or resume a session based on messages
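    ///
    /// If the most recent server message embeds a `<checkpoint_id>` marker, the existing
    /// session is resumed from that checkpoint. Otherwise a new session is created (via
    /// the Stakpak API when available, otherwise in the local DB) with a generated title,
    /// falling back to the first few words of the user message if title generation fails.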
    pub(crate) async fn initialize_session(
        &self,
        messages: &[ChatMessage],
    ) -> Result<RunAgentOutput, String> {
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // Check if we have an existing checkpoint ID in messages
        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        if let Some(checkpoint_id) = checkpoint_id {
            // Resume existing session
            return self.get_agent_checkpoint(checkpoint_id).await;
        }

        // Create new session
        // Generate title with fallback - don't fail session creation if title generation fails
        let title = match self.generate_session_title(messages).await {
            Ok(title) => title,
            Err(_) => {
                // Extract first few words from user message as fallback title
                messages
                    .iter()
                    .find(|m| m.role == Role::User)
                    .and_then(|m| m.content.as_ref())
                    .map(|c| {
                        let text = c.to_string();
                        text.split_whitespace()
                            .take(5)
                            .collect::<Vec<_>>()
                            .join(" ")
                    })
                    .unwrap_or_else(|| "New Session".to_string())
            }
        };

        // Get current working directory
        let cwd = std::env::current_dir()
            .ok()
            .map(|p| p.to_string_lossy().to_string());

        if let Some(api) = &self.stakpak_api {
            // Create session via Stakpak API (includes initial checkpoint)
            let mut session_request = CreateSessionRequest::new(
                title,
                CheckpointState {
                    messages: messages.to_vec(),
                },
            );
            if let Some(cwd) = cwd {
                session_request = session_request.with_cwd(cwd);
            }
            let response = api.create_session(&session_request).await?;

            Ok(RunAgentOutput {
                checkpoint: AgentCheckpointListItem {
                    id: response.checkpoint.id,
                    status: AgentStatus::Complete,
                    execution_depth: 0,
                    parent: response
                        .checkpoint
                        .parent_id
                        .map(|id| AgentParentCheckpoint { id }),
                    created_at: response.checkpoint.created_at,
                    updated_at: response.checkpoint.updated_at,
                },
                session: AgentSessionListItem {
                    id: response.session_id,
                    agent_id: AgentID::PabloV1,
                    visibility: AgentSessionVisibility::Private,
                    created_at: response.checkpoint.created_at,
                    updated_at: response.checkpoint.updated_at,
                },
                output: AgentOutput::PabloV1 {
                    messages: messages.to_vec(),
                    node_states: serde_json::json!({}),
                },
            })
        } else {
            // Create locally
            let now = chrono::Utc::now();
            let session_id = Uuid::new_v4();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.local_db, &session).await?;

            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.local_db, session_id, &checkpoint, &initial_state).await?;

            db::get_checkpoint(&self.local_db, checkpoint_id).await
        }
    }

    /// Update session with new messages
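    ///
    /// Persists `new_messages` as a new checkpoint whose parent is the current checkpoint,
    /// incrementing the execution depth. Routed through the Stakpak API when available,
    /// otherwise stored in the local DB.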
    pub(crate) async fn update_session(
        &self,
        checkpoint_info: &RunAgentOutput,
        new_messages: Vec<ChatMessage>,
    ) -> Result<RunAgentOutput, String> {
        if let Some(api) = &self.stakpak_api {
            // Add checkpoint via Stakpak API
            let checkpoint_request = CreateCheckpointRequest::new(CheckpointState {
                messages: new_messages.clone(),
            })
            .with_parent(checkpoint_info.checkpoint.id);

            let response = api
                .create_checkpoint(checkpoint_info.session.id, &checkpoint_request)
                .await?;

            Ok(RunAgentOutput {
                checkpoint: AgentCheckpointListItem {
                    id: response.checkpoint.id,
                    status: AgentStatus::Complete,
                    execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
                    parent: Some(AgentParentCheckpoint {
                        id: checkpoint_info.checkpoint.id,
                    }),
                    created_at: response.checkpoint.created_at,
                    updated_at: response.checkpoint.updated_at,
                },
                session: checkpoint_info.session.clone(),
                output: AgentOutput::PabloV1 {
                    messages: new_messages,
                    node_states: serde_json::json!({}),
                },
            })
        } else {
            // Create checkpoint locally
            let now = chrono::Utc::now();
            let complete_checkpoint = AgentCheckpointListItem {
                id: Uuid::new_v4(),
                status: AgentStatus::Complete,
                execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
                parent: Some(AgentParentCheckpoint {
                    id: checkpoint_info.checkpoint.id,
                }),
                created_at: now,
                updated_at: now,
            };

            let new_state = AgentOutput::PabloV1 {
                messages: new_messages.clone(),
                node_states: serde_json::json!({}),
            };

            db::create_checkpoint(
                &self.local_db,
                checkpoint_info.session.id,
                &complete_checkpoint,
                &new_state,
            )
            .await?;

            Ok(RunAgentOutput {
                checkpoint: complete_checkpoint,
                session: checkpoint_info.session.clone(),
                output: new_state,
            })
        }
    }

    /// Run agent completion (inference)
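    ///
    /// Executes BeforeInference hooks, performs the LLM call (forwarding deltas through
    /// `stream_channel_tx` when provided), then executes AfterInference hooks and returns
    /// the resulting assistant message. Requires `ctx.state.llm_input` to have been
    /// populated by a context hook beforehand.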
    pub(crate) async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        // Execute before inference hooks
        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        // Inject session_id header if available
        if let Some(session_id) = ctx.session_id {
            let headers = input
                .headers
                .get_or_insert_with(std::collections::HashMap::new);
            headers.insert("X-Session-Id".to_string(), session_id.to_string());
        }

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            // Streaming mode
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let stream_input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
                provider_options: input.provider_options,
                headers: input.headers,
            };

            let stakai = self.stakai.clone();
            let chat_future = async move {
                stakai
                    .chat_stream(stream_input)
                    .await
                    .map_err(|e| e.to_string())
            };

            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            // Non-streaming mode
            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        // Execute after inference hooks
        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

    /// Generate a title for a new session
    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
            eco_model.clone()
        } else {
            // Try to find a suitable model
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        };

        // If Stakpak is available, route through it
        let model = if self.has_stakpak() {
            // Get properly formatted model string with provider prefix (e.g., "anthropic/claude-haiku-4-5")
            let model_str = get_stakai_model_string(&llm_model);
            // Extract display name from the last segment for UI
            let display_name = model_str
                .rsplit('/')
                .next()
                .unwrap_or(&model_str)
                .to_string();
            LLMModel::Custom {
                provider: "stakpak".to_string(),
                model: model_str,
                name: Some(display_name),
            }
        } else {
            llm_model
        };

        let llm_messages = vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        .collect(),
                ),
            },
        ];

        let input = LLMInput {
            model,
            messages: llm_messages,
            max_tokens: 100,
            tools: None,
            provider_options: None,
            headers: None,
        };

        let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;

        Ok(response.choices[0].message.content.to_string())
    }
}