// stakpak_api/client/provider.rs

1//! AgentProvider trait implementation for AgentClient
2//!
3//! Implements the unified provider interface with:
4//! - Stakpak-first routing when API key is present
5//! - Local fallback when Stakpak is unavailable
6//! - Hook registry integration for lifecycle events
7
8use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11    CreateCheckpointRequest as StorageCreateCheckpointRequest,
12    CreateSessionRequest as StorageCreateSessionRequest,
13    UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22    ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Lightweight session info returned by initialize_session / save_checkpoint
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    // Session this checkpoint belongs to.
    session_id: Uuid,
    // Latest checkpoint; used as the parent when the next checkpoint is saved.
    checkpoint_id: Uuid,
    // Creation time of that checkpoint; surfaced as `created` in completion responses.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
42// =============================================================================
43// Internal Message Types
44// =============================================================================
45
/// Messages sent from the spawned completion task to the stream consumer.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// Incremental generation output from the LLM.
    Delta(GenerationDelta),
    /// Updated hook context snapshot (boxed to keep the variant small).
    /// Sent once after the assistant message is appended and again after
    /// the checkpoint is saved with the new checkpoint id.
    Ctx(Box<HookContext<AgentState>>),
}
51
52// =============================================================================
53// AgentProvider Implementation
54// =============================================================================
55
56#[async_trait]
57impl AgentProvider for AgentClient {
58    // =========================================================================
59    // Account
60    // =========================================================================
61
62    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63        if let Some(api) = &self.stakpak_api {
64            api.get_account().await
65        } else {
66            // Local stub
67            Ok(GetMyAccountResponse {
68                username: "local".to_string(),
69                id: "local".to_string(),
70                first_name: "local".to_string(),
71                last_name: "local".to_string(),
72                email: "local@stakpak.dev".to_string(),
73                scope: None,
74            })
75        }
76    }
77
78    async fn get_billing_info(
79        &self,
80        account_username: &str,
81    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82        if let Some(api) = &self.stakpak_api {
83            api.get_billing(account_username).await
84        } else {
85            Err("Billing info not available without Stakpak API key".to_string())
86        }
87    }
88
89    // =========================================================================
90    // Rulebooks
91    // =========================================================================
92
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        if let Some(api) = &self.stakpak_api {
            api.list_rulebooks().await
        } else {
            // Try to fetch public rulebooks via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
                // Tolerate schema drift: a body that fails to deserialize
                // degrades to an empty list rather than an error.
                match serde_json::from_value::<ListRulebooksResponse>(value) {
                    Ok(resp) => Ok(resp.results),
                    Err(_) => Ok(vec![]),
                }
            } else {
                // Non-2xx (e.g. auth-gated endpoint) also degrades to empty.
                Ok(vec![])
            }
        }
    }
117
    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_rulebook_by_uri(uri).await
        } else {
            // Try to fetch public rulebook via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            // Rulebook URIs can contain reserved characters, so percent-encode
            // before splicing into the path.
            let encoded_uri = urlencoding::encode(uri);
            let url = format!(
                "{}/v1/rules/{}",
                self.get_stakpak_api_endpoint(),
                encoded_uri
            );
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                response.json().await.map_err(|e| e.to_string())
            } else {
                // Any non-2xx (404, auth-gated private rulebook, ...) maps to
                // a single generic not-found error.
                Err("Rulebook not found".to_string())
            }
        }
    }
143
144    async fn create_rulebook(
145        &self,
146        uri: &str,
147        description: &str,
148        content: &str,
149        tags: Vec<String>,
150        visibility: Option<RuleBookVisibility>,
151    ) -> Result<CreateRuleBookResponse, String> {
152        if let Some(api) = &self.stakpak_api {
153            api.create_rulebook(&CreateRuleBookInput {
154                uri: uri.to_string(),
155                description: description.to_string(),
156                content: content.to_string(),
157                tags,
158                visibility,
159            })
160            .await
161        } else {
162            Err("Creating rulebooks requires Stakpak API key".to_string())
163        }
164    }
165
166    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167        if let Some(api) = &self.stakpak_api {
168            api.delete_rulebook(uri).await
169        } else {
170            Err("Deleting rulebooks requires Stakpak API key".to_string())
171        }
172    }
173
174    // =========================================================================
175    // Chat Completion
176    // =========================================================================
177
178    async fn chat_completion(
179        &self,
180        model: Model,
181        messages: Vec<ChatMessage>,
182        tools: Option<Vec<Tool>>,
183        session_id: Option<Uuid>,
184    ) -> Result<ChatCompletionResponse, String> {
185        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
186
187        // Execute before request hooks
188        self.hook_registry
189            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
190            .await
191            .map_err(|e| e.to_string())?
192            .ok()?;
193
194        // Initialize or resume session
195        let current_session = self.initialize_session(&ctx).await?;
196        ctx.set_session_id(current_session.session_id);
197
198        // Run completion
199        let new_message = self.run_agent_completion(&mut ctx, None).await?;
200        ctx.state.append_new_message(new_message.clone());
201
202        // Save checkpoint
203        let result = self
204            .save_checkpoint(&current_session, ctx.state.messages.clone())
205            .await?;
206        let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
207        ctx.set_new_checkpoint_id(result.checkpoint_id);
208
209        // Execute after request hooks
210        self.hook_registry
211            .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
212            .await
213            .map_err(|e| e.to_string())?
214            .ok()?;
215
216        let mut meta = serde_json::Map::new();
217        if let Some(session_id) = ctx.session_id {
218            meta.insert(
219                "session_id".to_string(),
220                serde_json::Value::String(session_id.to_string()),
221            );
222        }
223        if let Some(checkpoint_id) = ctx.new_checkpoint_id {
224            meta.insert(
225                "checkpoint_id".to_string(),
226                serde_json::Value::String(checkpoint_id.to_string()),
227            );
228        }
229
230        Ok(ChatCompletionResponse {
231            id: ctx.new_checkpoint_id.unwrap().to_string(),
232            object: "chat.completion".to_string(),
233            created: checkpoint_created_at,
234            model: ctx
235                .state
236                .llm_input
237                .as_ref()
238                .map(|llm_input| llm_input.model.id.clone())
239                .unwrap_or_default(),
240            choices: vec![ChatCompletionChoice {
241                index: 0,
242                message: ctx.state.messages.last().cloned().unwrap(),
243                logprobs: None,
244                finish_reason: FinishReason::Stop,
245            }],
246            usage: ctx
247                .state
248                .llm_output
249                .as_ref()
250                .map(|u| u.usage.clone())
251                .unwrap_or_default(),
252            system_fingerprint: None,
253            metadata: if meta.is_empty() {
254                None
255            } else {
256                Some(serde_json::Value::Object(meta))
257            },
258        })
259    }
260
    /// Streaming variant of `chat_completion`.
    ///
    /// Spawns the completion as a background task and returns a stream of
    /// OpenAI-style chunks. The producer sends `StreamMessage::Delta` for
    /// generation output and `StreamMessage::Ctx` snapshots (after the
    /// assistant message is appended, and again after checkpoint save) which
    /// the consumer turns into metadata-only chunks carrying
    /// `session_id`/`checkpoint_id`.
    async fn chat_completion_stream(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));

        // Execute before request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        // Bounded channel: producer (spawned task) -> consumer (stream below).
        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Clone what we need for the spawned task
        let client = self.clone();
        let mut ctx_clone = ctx.clone();

        // Spawn the completion task with proper shutdown handling
        // The task checks if the channel is closed before each expensive operation
        // to support graceful shutdown when the stream consumer is dropped
        tokio::spawn(async move {
            // Check if consumer is still listening before starting
            if tx.is_closed() {
                return;
            }

            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    // Check if consumer is still listening before continuing
                    if tx.is_closed() {
                        return;
                    }

                    ctx_clone.state.append_new_message(new_message.clone());
                    // First Ctx snapshot: transcript updated, no checkpoint yet.
                    if tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await
                        .is_err()
                    {
                        // Consumer dropped, exit gracefully
                        return;
                    }

                    // Check again before expensive session update
                    if tx.is_closed() {
                        return;
                    }

                    let result = client
                        .save_checkpoint(&current_session, ctx_clone.state.messages.clone())
                        .await;

                    match result {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(updated) => {
                            // Second Ctx snapshot: now carries the new checkpoint id.
                            ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            // Adopt the producer's latest context snapshot.
                            ctx = *updated_ctx;
                            // Emit session metadata so callers can track session_id
                            if let Some(session_id) = ctx.session_id {
                                let mut meta = serde_json::Map::new();
                                meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
                                if let Some(checkpoint_id) = ctx.new_checkpoint_id {
                                    meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
                                }
                                // Metadata-only chunk: no choices, no usage.
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: String::new(),
                                    choices: vec![],
                                    usage: None,
                                    metadata: Some(serde_json::Value::Object(meta)),
                                });
                            }
                        }
                        StreamMessage::Delta(delta) => {
                            // Extract usage from Usage delta variant
                            let usage = if let GenerationDelta::Usage { usage } = &delta {
                                Some(usage.clone())
                            } else {
                                None
                            };

                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage,
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Execute after request hooks
            // (runs once the producer closes the channel, i.e. stream end)
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        };

        Ok((Box::pin(stream), None))
    }
413
414    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
415        if let Some(api) = &self.stakpak_api {
416            api.cancel_request(&request_id).await
417        } else {
418            // Local mode doesn't support cancellation yet
419            Ok(())
420        }
421    }
422
423    // =========================================================================
424    // Search Docs
425    // =========================================================================
426
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.search_docs(&crate::stakpak::SearchDocsRequest {
                keywords: input.keywords.clone(),
                exclude_keywords: input.exclude_keywords.clone(),
                limit: input.limit,
            })
            .await
        } else {
            // Fallback to local search service
            use stakpak_shared::models::integrations::search_service::*;

            // Starts (or attaches to) the local search service and talks to
            // it over HTTP on the port it reports.
            let config = SearchServicesOrchestrator::start()
                .await
                .map_err(|e| e.to_string())?;

            let api_url = format!("http://localhost:{}", config.api_port);
            let search_client = SearchClient::new(api_url);

            // NOTE(review): `exclude_keywords` and `limit` are dropped in the
            // local path — only keywords are forwarded. Confirm intentional.
            let search_results = search_client
                .search_and_scrape(input.keywords.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                return Ok(vec![Content::text("No results found".to_string())]);
            }

            // One text content per result: URL plus scraped page content.
            Ok(search_results
                .into_iter()
                .map(|result| {
                    let content = result.content.unwrap_or_default();
                    Content::text(format!("URL: {}\nContent: {}", result.url, content))
                })
                .collect())
        }
    }
464
465    // =========================================================================
466    // Memory
467    // =========================================================================
468
469    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
470        if let Some(api) = &self.stakpak_api {
471            api.memorize_session(checkpoint_id).await
472        } else {
473            // No-op in local mode
474            Ok(())
475        }
476    }
477
    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        // Memory search is backed by the hosted API; local mode has no memory
        // store and returns no matches rather than erroring.
        if let Some(api) = &self.stakpak_api {
            api.search_memory(&crate::stakpak::SearchMemoryRequest {
                keywords: input.keywords.clone(),
                start_time: input.start_time,
                end_time: input.end_time,
            })
            .await
        } else {
            // Empty results in local mode
            Ok(vec![])
        }
    }
491
492    // =========================================================================
493    // Slack
494    // =========================================================================
495
496    async fn slack_read_messages(
497        &self,
498        input: &SlackReadMessagesRequest,
499    ) -> Result<Vec<Content>, String> {
500        if let Some(api) = &self.stakpak_api {
501            api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
502                channel: input.channel.clone(),
503                limit: input.limit,
504            })
505            .await
506        } else {
507            Err("Slack integration requires Stakpak API key".to_string())
508        }
509    }
510
511    async fn slack_read_replies(
512        &self,
513        input: &SlackReadRepliesRequest,
514    ) -> Result<Vec<Content>, String> {
515        if let Some(api) = &self.stakpak_api {
516            api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
517                channel: input.channel.clone(),
518                ts: input.ts.clone(),
519            })
520            .await
521        } else {
522            Err("Slack integration requires Stakpak API key".to_string())
523        }
524    }
525
526    async fn slack_send_message(
527        &self,
528        input: &SlackSendMessageRequest,
529    ) -> Result<Vec<Content>, String> {
530        if let Some(api) = &self.stakpak_api {
531            api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
532                channel: input.channel.clone(),
533                markdown_text: input.markdown_text.clone(),
534                thread_ts: input.thread_ts.clone(),
535            })
536            .await
537        } else {
538            Err("Slack integration requires Stakpak API key".to_string())
539        }
540    }
541
542    // =========================================================================
543    // Models
544    // =========================================================================
545
546    async fn list_models(&self) -> Vec<stakai::Model> {
547        const PROVIDERS: &[&str] = &["anthropic", "openai", "google"];
548
549        let use_stakpak = self.has_stakpak();
550        let mut all_models = Vec::new();
551
552        for &provider_id in PROVIDERS {
553            let mut models = load_and_transform_models(provider_id, use_stakpak);
554            sort_models_by_recency(&mut models);
555            all_models.extend(models);
556        }
557
558        all_models
559    }
560}
561
562/// Load models for a provider from cache, optionally transforming for Stakpak routing
563fn load_and_transform_models(provider_id: &str, use_stakpak: bool) -> Vec<stakai::Model> {
564    let models = stakai::load_models_for_provider(provider_id).unwrap_or_default();
565
566    if use_stakpak {
567        models
568            .into_iter()
569            .map(|m| stakai::Model {
570                id: format!("{}/{}", provider_id, m.id),
571                provider: "stakpak".into(),
572                name: m.name,
573                reasoning: m.reasoning,
574                cost: m.cost,
575                limit: m.limit,
576                release_date: m.release_date,
577            })
578            .collect()
579    } else {
580        models
581    }
582}
583
584/// Sort models by release_date descending (newest first)
585fn sort_models_by_recency(models: &mut [stakai::Model]) {
586    models.sort_by(|a, b| {
587        match (&b.release_date, &a.release_date) {
588            (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
589            (Some(_), None) => std::cmp::Ordering::Less,
590            (None, Some(_)) => std::cmp::Ordering::Greater,
591            (None, None) => b.id.cmp(&a.id), // Fallback to ID descending
592        }
593    });
594}
595
596// =============================================================================
597// SessionStorage implementation (delegates to inner session_storage)
598// =============================================================================
599
#[async_trait]
impl crate::storage::SessionStorage for super::AgentClient {
    // Every method forwards verbatim to the inner `session_storage` backend;
    // this impl exists so AgentClient itself satisfies the SessionStorage
    // trait without callers reaching into the field.

    async fn list_sessions(
        &self,
        query: &crate::storage::ListSessionsQuery,
    ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
        self.session_storage.list_sessions(query).await
    }

    async fn get_session(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage.get_session(session_id).await
    }

    async fn create_session(
        &self,
        request: &crate::storage::CreateSessionRequest,
    ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
        self.session_storage.create_session(request).await
    }

    async fn update_session(
        &self,
        session_id: Uuid,
        request: &crate::storage::UpdateSessionRequest,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage
            .update_session(session_id, request)
            .await
    }

    async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
        self.session_storage.delete_session(session_id).await
    }

    async fn list_checkpoints(
        &self,
        session_id: Uuid,
        query: &crate::storage::ListCheckpointsQuery,
    ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
        self.session_storage
            .list_checkpoints(session_id, query)
            .await
    }

    async fn get_checkpoint(
        &self,
        checkpoint_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_checkpoint(checkpoint_id).await
    }

    async fn create_checkpoint(
        &self,
        session_id: Uuid,
        request: &crate::storage::CreateCheckpointRequest,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage
            .create_checkpoint(session_id, request)
            .await
    }

    async fn get_active_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_active_checkpoint(session_id).await
    }

    async fn get_session_stats(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
        self.session_storage.get_session_stats(session_id).await
    }
}
678
679// =============================================================================
680// Helper Methods
681// =============================================================================
682
// Prompt template for session title generation, embedded at compile time.
// Presumably consumed by `generate_session_title` (defined outside this view).
const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");
685
686impl AgentClient {
687    /// Initialize or resume a session based on context
688    ///
689    /// If `ctx.session_id` is set, we resume that session directly.
690    /// Otherwise, we create a new session.
691    pub(crate) async fn initialize_session(
692        &self,
693        ctx: &HookContext<AgentState>,
694    ) -> Result<SessionInfo, String> {
695        let messages = &ctx.state.messages;
696
697        if messages.is_empty() {
698            return Err("At least one message is required".to_string());
699        }
700
701        // If session_id is set in context, resume that session directly
702        if let Some(session_id) = ctx.session_id {
703            let session = self
704                .session_storage
705                .get_session(session_id)
706                .await
707                .map_err(|e| e.to_string())?;
708
709            let checkpoint = session
710                .active_checkpoint
711                .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;
712
713            // If the session still has the default title, generate a better one in the background.
714            if session.title.trim().is_empty() || session.title == "New Session" {
715                let client = self.clone();
716                let messages_for_title = messages.to_vec();
717                let session_id = session.id;
718                let existing_title = session.title.clone();
719                tokio::spawn(async move {
720                    if let Ok(title) = client.generate_session_title(&messages_for_title).await {
721                        let trimmed = title.trim();
722                        if !trimmed.is_empty() && trimmed != existing_title {
723                            let request =
724                                StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
725                            let _ = client
726                                .session_storage
727                                .update_session(session_id, &request)
728                                .await;
729                        }
730                    }
731                });
732            }
733
734            return Ok(SessionInfo {
735                session_id: session.id,
736                checkpoint_id: checkpoint.id,
737                checkpoint_created_at: checkpoint.created_at,
738            });
739        }
740
741        // Create new session with a fast local title.
742        let fallback_title = Self::fallback_session_title(messages);
743
744        // Get current working directory
745        let cwd = std::env::current_dir()
746            .ok()
747            .map(|p| p.to_string_lossy().to_string());
748
749        // Create session via storage trait
750        let mut session_request =
751            StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
752        if let Some(cwd) = cwd {
753            session_request = session_request.with_cwd(cwd);
754        }
755
756        let result = self
757            .session_storage
758            .create_session(&session_request)
759            .await
760            .map_err(|e| e.to_string())?;
761
762        // Generate a better title asynchronously and update the session when ready.
763        let client = self.clone();
764        let messages_for_title = messages.to_vec();
765        let session_id = result.session_id;
766        tokio::spawn(async move {
767            if let Ok(title) = client.generate_session_title(&messages_for_title).await {
768                let trimmed = title.trim();
769                if !trimmed.is_empty() && trimmed != fallback_title {
770                    let request =
771                        StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
772                    let _ = client
773                        .session_storage
774                        .update_session(session_id, &request)
775                        .await;
776                }
777            }
778        });
779
780        Ok(SessionInfo {
781            session_id: result.session_id,
782            checkpoint_id: result.checkpoint.id,
783            checkpoint_created_at: result.checkpoint.created_at,
784        })
785    }
786
787    fn fallback_session_title(messages: &[ChatMessage]) -> String {
788        messages
789            .iter()
790            .find(|m| m.role == Role::User)
791            .and_then(|m| m.content.as_ref())
792            .map(|c| {
793                let text = c.to_string();
794                text.split_whitespace()
795                    .take(5)
796                    .collect::<Vec<_>>()
797                    .join(" ")
798            })
799            .unwrap_or_else(|| "New Session".to_string())
800    }
801
    /// Save a new checkpoint for the current session
    ///
    /// Creates a child checkpoint of `current.checkpoint_id` holding the full
    /// message transcript and returns a `SessionInfo` advanced to it.
    pub(crate) async fn save_checkpoint(
        &self,
        current: &SessionInfo,
        messages: Vec<ChatMessage>,
    ) -> Result<SessionInfo, String> {
        // Link the new checkpoint to its predecessor so history forms a chain.
        let checkpoint_request =
            StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);

        let checkpoint = self
            .session_storage
            .create_checkpoint(current.session_id, &checkpoint_request)
            .await
            .map_err(|e| e.to_string())?;

        // Same session, now pointing at the freshly created checkpoint.
        Ok(SessionInfo {
            session_id: current.session_id,
            checkpoint_id: checkpoint.id,
            checkpoint_created_at: checkpoint.created_at,
        })
    }
823
824    /// Run agent completion (inference)
825    pub(crate) async fn run_agent_completion(
826        &self,
827        ctx: &mut HookContext<AgentState>,
828        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
829    ) -> Result<ChatMessage, String> {
830        // Execute before inference hooks
831        self.hook_registry
832            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
833            .await
834            .map_err(|e| e.to_string())?
835            .ok()?;
836
837        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
838            llm_input
839        } else {
840            return Err(
841                "LLM input not found, make sure to register a context hook before inference"
842                    .to_string(),
843            );
844        };
845
846        // Inject session_id header if available
847        if let Some(session_id) = ctx.session_id {
848            let headers = input
849                .headers
850                .get_or_insert_with(std::collections::HashMap::new);
851            headers.insert("X-Session-Id".to_string(), session_id.to_string());
852        }
853
854        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
855            // Streaming mode
856            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
857            let stream_input = LLMStreamInput {
858                model: input.model,
859                messages: input.messages,
860                max_tokens: input.max_tokens,
861                tools: input.tools,
862                stream_channel_tx: internal_tx,
863                provider_options: input.provider_options,
864                headers: input.headers,
865            };
866
867            let stakai = self.stakai.clone();
868            let chat_future = async move {
869                stakai
870                    .chat_stream(stream_input)
871                    .await
872                    .map_err(|e| e.to_string())
873            };
874
875            let receive_future = async move {
876                while let Some(delta) = internal_rx.recv().await {
877                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
878                        break;
879                    }
880                }
881            };
882
883            let (chat_result, _) = tokio::join!(chat_future, receive_future);
884            let response = chat_result?;
885            (response.choices[0].message.clone(), response.usage)
886        } else {
887            // Non-streaming mode
888            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
889            (response.choices[0].message.clone(), response.usage)
890        };
891
892        ctx.state.set_llm_output(response_message, usage);
893
894        // Execute after inference hooks
895        self.hook_registry
896            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
897            .await
898            .map_err(|e| e.to_string())?
899            .ok()?;
900
901        let llm_output = ctx
902            .state
903            .llm_output
904            .as_ref()
905            .ok_or_else(|| "LLM output is missing from state".to_string())?;
906
907        Ok(ChatMessage::from(llm_output))
908    }
909
910    /// Generate a title for a new session
911    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
912        // Use a default haiku model for title generation
913        let model = Model::new(
914            "claude-haiku-4-5-20250929",
915            "Claude Haiku 4.5",
916            "anthropic",
917            false,
918            None,
919            stakai::ModelLimit::default(),
920        );
921
922        let llm_messages = vec![
923            LLMMessage {
924                role: Role::System.to_string(),
925                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
926            },
927            LLMMessage {
928                role: Role::User.to_string(),
929                content: LLMMessageContent::String(
930                    messages
931                        .iter()
932                        .map(|msg| {
933                            msg.content
934                                .as_ref()
935                                .unwrap_or(&MessageContent::String("".to_string()))
936                                .to_string()
937                        })
938                        .collect(),
939                ),
940            },
941        ];
942
943        let input = LLMInput {
944            model,
945            messages: llm_messages,
946            max_tokens: 100,
947            tools: None,
948            provider_options: None,
949            headers: None,
950        };
951
952        let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
953
954        Ok(response.choices[0].message.content.to_string())
955    }
956}