Skip to main content

stakpak_api/client/
provider.rs

1//! AgentProvider trait implementation for AgentClient
2//!
3//! Implements the unified provider interface with:
4//! - Stakpak-first routing when API key is present
5//! - Local fallback when Stakpak is unavailable
6//! - Hook registry integration for lifecycle events
7
8use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11    CreateCheckpointRequest as StorageCreateCheckpointRequest,
12    CreateSessionRequest as StorageCreateSessionRequest,
13    UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22    ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Lightweight session info returned by initialize_session / save_checkpoint
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    // Identifier of the session this info belongs to.
    session_id: Uuid,
    // Identifier of the session's latest checkpoint.
    checkpoint_id: Uuid,
    // Creation time (UTC) of that checkpoint; used as the `created`
    // timestamp of chat completion responses.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
42// =============================================================================
43// Internal Message Types
44// =============================================================================
45
/// Internal message sent from the spawned completion task to the
/// response stream.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// An incremental generation delta from the model.
    Delta(GenerationDelta),
    /// A full updated hook context (boxed to keep the variant small);
    /// sent after the assistant message is appended and again after the
    /// checkpoint is saved.
    Ctx(Box<HookContext<AgentState>>),
}
51
52// =============================================================================
53// AgentProvider Implementation
54// =============================================================================
55
56#[async_trait]
57impl AgentProvider for AgentClient {
58    // =========================================================================
59    // Account
60    // =========================================================================
61
62    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63        if let Some(api) = &self.stakpak_api {
64            api.get_account().await
65        } else {
66            // Local stub
67            Ok(GetMyAccountResponse {
68                username: "local".to_string(),
69                id: "local".to_string(),
70                first_name: "local".to_string(),
71                last_name: "local".to_string(),
72                email: "local@stakpak.dev".to_string(),
73                scope: None,
74            })
75        }
76    }
77
78    async fn get_billing_info(
79        &self,
80        account_username: &str,
81    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82        if let Some(api) = &self.stakpak_api {
83            api.get_billing(account_username).await
84        } else {
85            Err("Billing info not available without Stakpak API key".to_string())
86        }
87    }
88
89    // =========================================================================
90    // Rulebooks
91    // =========================================================================
92
93    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
94        if let Some(api) = &self.stakpak_api {
95            api.list_rulebooks().await
96        } else {
97            // Try to fetch public rulebooks via unauthenticated request
98            let client = stakpak_shared::tls_client::create_tls_client(
99                stakpak_shared::tls_client::TlsClientConfig::default()
100                    .with_timeout(std::time::Duration::from_secs(30)),
101            )?;
102
103            let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
104            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
105
106            if response.status().is_success() {
107                let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
108                match serde_json::from_value::<ListRulebooksResponse>(value) {
109                    Ok(resp) => Ok(resp.results),
110                    Err(_) => Ok(vec![]),
111                }
112            } else {
113                Ok(vec![])
114            }
115        }
116    }
117
118    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
119        if let Some(api) = &self.stakpak_api {
120            api.get_rulebook_by_uri(uri).await
121        } else {
122            // Try to fetch public rulebook via unauthenticated request
123            let client = stakpak_shared::tls_client::create_tls_client(
124                stakpak_shared::tls_client::TlsClientConfig::default()
125                    .with_timeout(std::time::Duration::from_secs(30)),
126            )?;
127
128            let encoded_uri = urlencoding::encode(uri);
129            let url = format!(
130                "{}/v1/rules/{}",
131                self.get_stakpak_api_endpoint(),
132                encoded_uri
133            );
134            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
135
136            if response.status().is_success() {
137                response.json().await.map_err(|e| e.to_string())
138            } else {
139                Err("Rulebook not found".to_string())
140            }
141        }
142    }
143
144    async fn create_rulebook(
145        &self,
146        uri: &str,
147        description: &str,
148        content: &str,
149        tags: Vec<String>,
150        visibility: Option<RuleBookVisibility>,
151    ) -> Result<CreateRuleBookResponse, String> {
152        if let Some(api) = &self.stakpak_api {
153            api.create_rulebook(&CreateRuleBookInput {
154                uri: uri.to_string(),
155                description: description.to_string(),
156                content: content.to_string(),
157                tags,
158                visibility,
159            })
160            .await
161        } else {
162            Err("Creating rulebooks requires Stakpak API key".to_string())
163        }
164    }
165
166    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167        if let Some(api) = &self.stakpak_api {
168            api.delete_rulebook(uri).await
169        } else {
170            Err("Deleting rulebooks requires Stakpak API key".to_string())
171        }
172    }
173
174    // =========================================================================
175    // Chat Completion
176    // =========================================================================
177
    /// Execute a complete (non-streaming) agent chat completion.
    ///
    /// Flow: BeforeRequest hooks -> create/resume session -> inference
    /// -> save checkpoint -> AfterRequest hooks. The response's `id` is
    /// the new checkpoint id, and `session_id` / `checkpoint_id` are
    /// surfaced through the response `metadata` so callers can resume.
    async fn chat_completion(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
        ctx.state.metadata = metadata;

        // Execute before request hooks; `.ok()?` turns a hook-signalled
        // rejection into an early error return.
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize or resume session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        // Run completion and record the assistant reply in the context.
        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Save checkpoint with the full message history and metadata.
        let result = self
            .save_checkpoint(
                &current_session,
                ctx.state.messages.clone(),
                ctx.state.metadata.clone(),
            )
            .await?;
        let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint_id);

        // Execute after request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Surface session/checkpoint ids so callers can track the session.
        let mut meta = serde_json::Map::new();
        if let Some(session_id) = ctx.session_id {
            meta.insert(
                "session_id".to_string(),
                serde_json::Value::String(session_id.to_string()),
            );
        }
        if let Some(checkpoint_id) = ctx.new_checkpoint_id {
            meta.insert(
                "checkpoint_id".to_string(),
                serde_json::Value::String(checkpoint_id.to_string()),
            );
        }

        Ok(ChatCompletionResponse {
            // Safe: set via set_new_checkpoint_id just above.
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            // Model id comes from the LLM input prepared by a context
            // hook; empty string when no hook populated it.
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.id.clone())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                // Safe: a message was appended above, so `last()` is Some.
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: if meta.is_empty() {
                None
            } else {
                Some(serde_json::Value::Object(meta))
            },
        })
    }
266
    /// Execute a streaming agent chat completion.
    ///
    /// Hooks and session setup run synchronously; inference and
    /// checkpointing run in a spawned task that feeds an mpsc channel.
    /// The returned stream converts channel messages into OpenAI-style
    /// chunks, and AfterRequest hooks run once the channel drains.
    /// The spawned task checks `tx.is_closed()` before each expensive
    /// step so it shuts down gracefully if the consumer drops the stream.
    async fn chat_completion_stream(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
        ctx.state.metadata = metadata;

        // Execute before request hooks; `.ok()?` turns a hook-signalled
        // rejection into an early error return.
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        // Bounded channel between the inference task and the stream.
        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Clone what we need for the spawned task
        let client = self.clone();
        let mut ctx_clone = ctx.clone();

        // Spawn the completion task with proper shutdown handling
        // The task checks if the channel is closed before each expensive operation
        // to support graceful shutdown when the stream consumer is dropped
        tokio::spawn(async move {
            // Check if consumer is still listening before starting
            if tx.is_closed() {
                return;
            }

            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    // Check if consumer is still listening before continuing
                    if tx.is_closed() {
                        return;
                    }

                    // Record the assistant message and let the stream
                    // observe the updated context (session id etc.).
                    ctx_clone.state.append_new_message(new_message.clone());
                    if tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await
                        .is_err()
                    {
                        // Consumer dropped, exit gracefully
                        return;
                    }

                    // Check again before expensive session update
                    if tx.is_closed() {
                        return;
                    }

                    let result = client
                        .save_checkpoint(
                            &current_session,
                            ctx_clone.state.messages.clone(),
                            ctx_clone.state.metadata.clone(),
                        )
                        .await;

                    match result {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(updated) => {
                            // Announce the new checkpoint id through a
                            // final context update message.
                            ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            // Adopt the task's context so later chunks see
                            // the latest session/checkpoint state.
                            ctx = *updated_ctx;
                            // Emit session metadata so callers can track session_id
                            if let Some(session_id) = ctx.session_id {
                                let mut meta = serde_json::Map::new();
                                meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
                                if let Some(checkpoint_id) = ctx.new_checkpoint_id {
                                    meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
                                }
                                // Metadata-only chunk: no choices, no usage.
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: String::new(),
                                    choices: vec![],
                                    usage: None,
                                    metadata: Some(serde_json::Value::Object(meta)),
                                });
                            }
                        }
                        StreamMessage::Delta(delta) => {
                            // Extract usage from Usage delta variant
                            let usage = if let GenerationDelta::Usage { usage } = &delta {
                                Some(usage.clone())
                            } else {
                                None
                            };

                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage,
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Execute after request hooks once the producer task is done.
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        };

        Ok((Box::pin(stream), None))
    }
425
426    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
427        if let Some(api) = &self.stakpak_api {
428            api.cancel_request(&request_id).await
429        } else {
430            // Local mode doesn't support cancellation yet
431            Ok(())
432        }
433    }
434
435    // =========================================================================
436    // Search Docs
437    // =========================================================================
438
439    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
440        if let Some(api) = &self.stakpak_api {
441            api.search_docs(&crate::stakpak::SearchDocsRequest {
442                keywords: input.keywords.clone(),
443                exclude_keywords: input.exclude_keywords.clone(),
444                limit: input.limit,
445            })
446            .await
447        } else {
448            // Fallback to local search service
449            use stakpak_shared::models::integrations::search_service::*;
450
451            let config = SearchServicesOrchestrator::start()
452                .await
453                .map_err(|e| e.to_string())?;
454
455            let api_url = format!("http://localhost:{}", config.api_port);
456            let search_client = SearchClient::new(api_url);
457
458            let search_results = search_client
459                .search_and_scrape(input.keywords.clone(), None)
460                .await
461                .map_err(|e| e.to_string())?;
462
463            if search_results.is_empty() {
464                return Ok(vec![Content::text("No results found".to_string())]);
465            }
466
467            Ok(search_results
468                .into_iter()
469                .map(|result| {
470                    let content = result.content.unwrap_or_default();
471                    Content::text(format!("URL: {}\nContent: {}", result.url, content))
472                })
473                .collect())
474        }
475    }
476
477    // =========================================================================
478    // Memory
479    // =========================================================================
480
481    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
482        if let Some(api) = &self.stakpak_api {
483            api.memorize_session(checkpoint_id).await
484        } else {
485            // No-op in local mode
486            Ok(())
487        }
488    }
489
490    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
491        if let Some(api) = &self.stakpak_api {
492            api.search_memory(&crate::stakpak::SearchMemoryRequest {
493                keywords: input.keywords.clone(),
494                start_time: input.start_time,
495                end_time: input.end_time,
496            })
497            .await
498        } else {
499            // Empty results in local mode
500            Ok(vec![])
501        }
502    }
503
504    // =========================================================================
505    // Slack
506    // =========================================================================
507
508    async fn slack_read_messages(
509        &self,
510        input: &SlackReadMessagesRequest,
511    ) -> Result<Vec<Content>, String> {
512        if let Some(api) = &self.stakpak_api {
513            api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
514                channel: input.channel.clone(),
515                limit: input.limit,
516            })
517            .await
518        } else {
519            Err("Slack integration requires Stakpak API key".to_string())
520        }
521    }
522
523    async fn slack_read_replies(
524        &self,
525        input: &SlackReadRepliesRequest,
526    ) -> Result<Vec<Content>, String> {
527        if let Some(api) = &self.stakpak_api {
528            api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
529                channel: input.channel.clone(),
530                ts: input.ts.clone(),
531            })
532            .await
533        } else {
534            Err("Slack integration requires Stakpak API key".to_string())
535        }
536    }
537
538    async fn slack_send_message(
539        &self,
540        input: &SlackSendMessageRequest,
541    ) -> Result<Vec<Content>, String> {
542        if let Some(api) = &self.stakpak_api {
543            api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
544                channel: input.channel.clone(),
545                markdown_text: input.markdown_text.clone(),
546                thread_ts: input.thread_ts.clone(),
547            })
548            .await
549        } else {
550            Err("Slack integration requires Stakpak API key".to_string())
551        }
552    }
553
554    // =========================================================================
555    // Models
556    // =========================================================================
557
558    async fn list_models(&self) -> Vec<stakai::Model> {
559        // Use the provider registry which only contains providers with configured API keys.
560        // This ensures we only list models for providers the user actually has access to.
561        // Aggregate per provider so one failing provider does not hide all others.
562        let registry = self.stakai.registry();
563        let mut all_models = Vec::new();
564
565        for provider_id in registry.list_providers() {
566            if let Ok(mut models) = registry.models_for_provider(&provider_id).await {
567                all_models.append(&mut models);
568            }
569        }
570
571        sort_models_by_recency(&mut all_models);
572        all_models
573    }
574}
575
576/// Sort models by release_date descending (newest first)
577fn sort_models_by_recency(models: &mut [stakai::Model]) {
578    models.sort_by(|a, b| {
579        match (&b.release_date, &a.release_date) {
580            (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
581            (Some(_), None) => std::cmp::Ordering::Less,
582            (None, Some(_)) => std::cmp::Ordering::Greater,
583            (None, None) => b.id.cmp(&a.id), // Fallback to ID descending
584        }
585    });
586}
587
588// =============================================================================
589// SessionStorage implementation (delegates to inner session_storage)
590// =============================================================================
591
#[async_trait]
impl crate::storage::SessionStorage for super::AgentClient {
    // Pure delegation: every method forwards unchanged to the inner
    // `session_storage` backend.

    /// List sessions matching the query.
    async fn list_sessions(
        &self,
        query: &crate::storage::ListSessionsQuery,
    ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
        self.session_storage.list_sessions(query).await
    }

    /// Fetch a single session by id.
    async fn get_session(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage.get_session(session_id).await
    }

    /// Create a new session.
    async fn create_session(
        &self,
        request: &crate::storage::CreateSessionRequest,
    ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
        self.session_storage.create_session(request).await
    }

    /// Update an existing session, returning the updated record.
    async fn update_session(
        &self,
        session_id: Uuid,
        request: &crate::storage::UpdateSessionRequest,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage
            .update_session(session_id, request)
            .await
    }

    /// Delete a session by id.
    async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
        self.session_storage.delete_session(session_id).await
    }

    /// List checkpoints belonging to a session.
    async fn list_checkpoints(
        &self,
        session_id: Uuid,
        query: &crate::storage::ListCheckpointsQuery,
    ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
        self.session_storage
            .list_checkpoints(session_id, query)
            .await
    }

    /// Fetch a single checkpoint by id.
    async fn get_checkpoint(
        &self,
        checkpoint_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_checkpoint(checkpoint_id).await
    }

    /// Create a new checkpoint under the given session.
    async fn create_checkpoint(
        &self,
        session_id: Uuid,
        request: &crate::storage::CreateCheckpointRequest,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage
            .create_checkpoint(session_id, request)
            .await
    }

    /// Fetch the session's active checkpoint.
    async fn get_active_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_active_checkpoint(session_id).await
    }

    /// Fetch aggregate statistics for a session.
    async fn get_session_stats(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
        self.session_storage.get_session_stats(session_id).await
    }
}
670
671// =============================================================================
672// Helper Methods
673// =============================================================================
674
// Prompt template for generating session titles, embedded at compile time.
// NOTE(review): presumably consumed by `generate_session_title`, whose body
// is not visible in this file — confirm.
const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");
677
678impl AgentClient {
    /// Initialize or resume a session based on context.
    ///
    /// If `ctx.session_id` is set, that session is resumed directly and
    /// its active checkpoint is returned. Otherwise a new session is
    /// created with a fast locally-derived title. In both cases a
    /// better title is generated in a background task and applied
    /// best-effort, so this call never blocks on title generation.
    ///
    /// # Errors
    /// Fails when the context has no messages, when a resumed session
    /// cannot be loaded or has no active checkpoint, or when session
    /// creation fails.
    pub(crate) async fn initialize_session(
        &self,
        ctx: &HookContext<AgentState>,
    ) -> Result<SessionInfo, String> {
        let messages = &ctx.state.messages;

        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // If session_id is set in context, resume that session directly
        if let Some(session_id) = ctx.session_id {
            let session = self
                .session_storage
                .get_session(session_id)
                .await
                .map_err(|e| e.to_string())?;

            let checkpoint = session
                .active_checkpoint
                .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;

            // If the session still has the default title, generate a better one in the background.
            if session.title.trim().is_empty() || session.title == "New Session" {
                let client = self.clone();
                let messages_for_title = messages.to_vec();
                let session_id = session.id;
                let existing_title = session.title.clone();
                tokio::spawn(async move {
                    if let Ok(title) = client.generate_session_title(&messages_for_title).await {
                        let trimmed = title.trim();
                        // Only update when generation produced a new,
                        // non-empty title.
                        if !trimmed.is_empty() && trimmed != existing_title {
                            let request =
                                StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
                            // Best-effort: a failed title update is ignored.
                            let _ = client
                                .session_storage
                                .update_session(session_id, &request)
                                .await;
                        }
                    }
                });
            }

            return Ok(SessionInfo {
                session_id: session.id,
                checkpoint_id: checkpoint.id,
                checkpoint_created_at: checkpoint.created_at,
            });
        }

        // Create new session with a fast local title.
        let fallback_title = Self::fallback_session_title(messages);

        // Get current working directory
        let cwd = std::env::current_dir()
            .ok()
            .map(|p| p.to_string_lossy().to_string());

        // Create session via storage trait
        let mut session_request =
            StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
        if let Some(cwd) = cwd {
            session_request = session_request.with_cwd(cwd);
        }

        let result = self
            .session_storage
            .create_session(&session_request)
            .await
            .map_err(|e| e.to_string())?;

        // Generate a better title asynchronously and update the session when ready.
        let client = self.clone();
        let messages_for_title = messages.to_vec();
        let session_id = result.session_id;
        tokio::spawn(async move {
            if let Ok(title) = client.generate_session_title(&messages_for_title).await {
                let trimmed = title.trim();
                if !trimmed.is_empty() && trimmed != fallback_title {
                    let request =
                        StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
                    // Best-effort: a failed title update is ignored.
                    let _ = client
                        .session_storage
                        .update_session(session_id, &request)
                        .await;
                }
            }
        });

        Ok(SessionInfo {
            session_id: result.session_id,
            checkpoint_id: result.checkpoint.id,
            checkpoint_created_at: result.checkpoint.created_at,
        })
    }
778
779    fn fallback_session_title(messages: &[ChatMessage]) -> String {
780        messages
781            .iter()
782            .find(|m| m.role == Role::User)
783            .and_then(|m| m.content.as_ref())
784            .map(|c| {
785                let text = c.to_string();
786                text.split_whitespace()
787                    .take(5)
788                    .collect::<Vec<_>>()
789                    .join(" ")
790            })
791            .unwrap_or_else(|| "New Session".to_string())
792    }
793
794    /// Save a new checkpoint for the current session
795    pub(crate) async fn save_checkpoint(
796        &self,
797        current: &SessionInfo,
798        messages: Vec<ChatMessage>,
799        metadata: Option<serde_json::Value>,
800    ) -> Result<SessionInfo, String> {
801        let mut checkpoint_request =
802            StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);
803
804        if let Some(meta) = metadata {
805            checkpoint_request = checkpoint_request.with_metadata(meta);
806        }
807
808        let checkpoint = self
809            .session_storage
810            .create_checkpoint(current.session_id, &checkpoint_request)
811            .await
812            .map_err(|e| e.to_string())?;
813
814        Ok(SessionInfo {
815            session_id: current.session_id,
816            checkpoint_id: checkpoint.id,
817            checkpoint_created_at: checkpoint.created_at,
818        })
819    }
820
821    /// Run agent completion (inference)
822    pub(crate) async fn run_agent_completion(
823        &self,
824        ctx: &mut HookContext<AgentState>,
825        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
826    ) -> Result<ChatMessage, String> {
827        // Execute before inference hooks
828        self.hook_registry
829            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
830            .await
831            .map_err(|e| e.to_string())?
832            .ok()?;
833
834        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
835            llm_input
836        } else {
837            return Err(
838                "LLM input not found, make sure to register a context hook before inference"
839                    .to_string(),
840            );
841        };
842
843        // Inject session_id header if available
844        if let Some(session_id) = ctx.session_id {
845            let headers = input
846                .headers
847                .get_or_insert_with(std::collections::HashMap::new);
848            headers.insert("X-Session-Id".to_string(), session_id.to_string());
849        }
850
851        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
852            // Streaming mode
853            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
854            let stream_input = LLMStreamInput {
855                model: input.model,
856                messages: input.messages,
857                max_tokens: input.max_tokens,
858                tools: input.tools,
859                stream_channel_tx: internal_tx,
860                provider_options: input.provider_options,
861                headers: input.headers,
862            };
863
864            let stakai = self.stakai.clone();
865            let chat_future = async move {
866                stakai
867                    .chat_stream(stream_input)
868                    .await
869                    .map_err(|e| e.to_string())
870            };
871
872            let receive_future = async move {
873                while let Some(delta) = internal_rx.recv().await {
874                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
875                        break;
876                    }
877                }
878            };
879
880            let (chat_result, _) = tokio::join!(chat_future, receive_future);
881            let response = chat_result?;
882            (response.choices[0].message.clone(), response.usage)
883        } else {
884            // Non-streaming mode
885            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
886            (response.choices[0].message.clone(), response.usage)
887        };
888
889        ctx.state.set_llm_output(response_message, usage);
890
891        // Execute after inference hooks
892        self.hook_registry
893            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
894            .await
895            .map_err(|e| e.to_string())?
896            .ok()?;
897
898        let llm_output = ctx
899            .state
900            .llm_output
901            .as_ref()
902            .ok_or_else(|| "LLM output is missing from state".to_string())?;
903
904        Ok(ChatMessage::from(llm_output))
905    }
906
907    /// Generate a title for a new session
908    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
909        // Pick a cheap model from the user's configured providers
910        let use_stakpak = self.stakpak.is_some();
911        let providers = self.stakai.registry().list_providers();
912        let cheap_models: &[(&str, &str)] = &[
913            ("stakpak", "claude-haiku-4-5"),
914            ("anthropic", "claude-haiku-4-5"),
915            ("openai", "gpt-4.1-mini"),
916            ("google", "gemini-2.5-flash"),
917        ];
918        let model = cheap_models
919            .iter()
920            .find_map(|(provider, model_id)| {
921                if providers.contains(&provider.to_string()) {
922                    crate::find_model(model_id, use_stakpak)
923                } else {
924                    None
925                }
926            })
927            .ok_or_else(|| "No model available for title generation".to_string())?;
928
929        let llm_messages = vec![
930            LLMMessage {
931                role: Role::System.to_string(),
932                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
933            },
934            LLMMessage {
935                role: Role::User.to_string(),
936                content: LLMMessageContent::String(
937                    messages
938                        .iter()
939                        .map(|msg| {
940                            msg.content
941                                .as_ref()
942                                .unwrap_or(&MessageContent::String("".to_string()))
943                                .to_string()
944                        })
945                        .collect(),
946                ),
947            },
948        ];
949
950        let input = LLMInput {
951            model,
952            messages: llm_messages,
953            max_tokens: 100,
954            tools: None,
955            provider_options: None,
956            headers: None,
957        };
958
959        let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
960
961        Ok(response.choices[0].message.content.to_string())
962    }
963}