Skip to main content

pi/interactive/
ext_session.rs

1use super::conversation::extension_model_from_entry;
2use super::*;
3
/// Host-side actions that extensions can invoke while running inside the interactive UI.
///
/// Bundles shared handles to the session, the agent, the UI event channel, and the queues
/// used to inject extension messages into the agent loop while it is streaming.
#[derive(Clone)]
pub(super) struct InteractiveExtensionHostActions {
    /// Shared session; custom messages sent while the agent is idle are appended here.
    pub(super) session: Arc<Mutex<Session>>,
    /// Shared agent; custom messages sent while idle are added to its message history.
    pub(super) agent: Arc<Mutex<Agent>>,
    /// Channel to the interactive UI (system notes, pending user input).
    pub(super) event_tx: mpsc::Sender<PiMsg>,
    /// Streaming flag; when set, messages are queued instead of delivered immediately.
    pub(super) extension_streaming: Arc<AtomicBool>,
    /// Queue for extension-sent user messages delivered during streaming.
    pub(super) user_queue: Arc<StdMutex<InteractiveMessageQueue>>,
    /// Queue for extension-sent custom model messages delivered during streaming.
    pub(super) injected_queue: Arc<StdMutex<InjectedMessageQueue>>,
}
13
14impl InteractiveExtensionHostActions {
15    #[allow(clippy::unnecessary_wraps)]
16    fn queue_custom_message(
17        &self,
18        deliver_as: Option<ExtensionDeliverAs>,
19        message: ModelMessage,
20    ) -> crate::error::Result<()> {
21        let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
22        let kind = match deliver_as {
23            ExtensionDeliverAs::FollowUp => QueuedMessageKind::FollowUp,
24            ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => QueuedMessageKind::Steering,
25        };
26        let Ok(mut queue) = self.injected_queue.lock() else {
27            return Ok(());
28        };
29        match kind {
30            QueuedMessageKind::Steering => queue.push_steering(message),
31            QueuedMessageKind::FollowUp => queue.push_follow_up(message),
32        }
33        Ok(())
34    }
35
36    async fn append_to_session(&self, message: ModelMessage) -> crate::error::Result<()> {
37        let cx = Cx::for_request();
38        let mut session_guard = self
39            .session
40            .lock(&cx)
41            .await
42            .map_err(|e| crate::error::Error::session(e.to_string()))?;
43        session_guard.append_model_message(message);
44        Ok(())
45    }
46}
47
48#[async_trait]
49impl ExtensionHostActions for InteractiveExtensionHostActions {
50    async fn send_message(&self, message: ExtensionSendMessage) -> crate::error::Result<()> {
51        let custom_message = ModelMessage::Custom(CustomMessage {
52            content: message.content,
53            custom_type: message.custom_type,
54            display: message.display,
55            details: message.details,
56            timestamp: Utc::now().timestamp_millis(),
57        });
58
59        let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
60        if is_streaming {
61            // Queue into the agent loop; session persistence happens when the message is delivered.
62            self.queue_custom_message(message.deliver_as, custom_message.clone())?;
63            if let ModelMessage::Custom(custom) = &custom_message {
64                if custom.display {
65                    let _ = self
66                        .event_tx
67                        .try_send(PiMsg::SystemNote(custom.content.clone()));
68                }
69            }
70            return Ok(());
71        }
72
73        // Agent is idle: persist immediately and update in-memory history so it affects the next run.
74        // Triggering a new turn for custom messages is handled separately and may be implemented later.
75        let _ = message.trigger_turn;
76        self.append_to_session(custom_message.clone()).await?;
77
78        let cx = Cx::for_request();
79        if let Ok(mut agent_guard) = self.agent.lock(&cx).await {
80            agent_guard.add_message(custom_message.clone());
81        }
82
83        if let ModelMessage::Custom(custom) = &custom_message {
84            if custom.display {
85                let _ = self
86                    .event_tx
87                    .try_send(PiMsg::SystemNote(custom.content.clone()));
88            }
89        }
90
91        Ok(())
92    }
93
94    async fn send_user_message(
95        &self,
96        message: ExtensionSendUserMessage,
97    ) -> crate::error::Result<()> {
98        let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
99        if is_streaming {
100            let deliver_as = message.deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
101            let Ok(mut queue) = self.user_queue.lock() else {
102                return Ok(());
103            };
104            match deliver_as {
105                ExtensionDeliverAs::FollowUp => queue.push_follow_up(message.text),
106                ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
107                    queue.push_steering(message.text);
108                }
109            }
110            return Ok(());
111        }
112
113        let _ = self
114            .event_tx
115            .try_send(PiMsg::EnqueuePendingInput(PendingInput::Text(message.text)));
116        Ok(())
117    }
118}
119
/// Session API exposed to extensions while running inside the interactive UI.
pub(super) struct InteractiveExtensionSession {
    /// Shared session read and mutated by the `ExtensionSession` methods.
    pub(super) session: Arc<Mutex<Session>>,
    /// Current model entry, reported as `model` in `get_state`.
    pub(super) model_entry: Arc<StdMutex<ModelEntry>>,
    /// Streaming flag, reported as `isStreaming` in `get_state`.
    pub(super) is_streaming: Arc<AtomicBool>,
    /// Compaction flag, reported as `isCompacting` in `get_state`.
    pub(super) is_compacting: Arc<AtomicBool>,
    /// Configuration; `compaction_enabled()` is reported as `autoCompactionEnabled`.
    pub(super) config: Config,
    /// When true, each mutating call saves the session after applying its change.
    pub(super) save_enabled: bool,
}
128
129#[async_trait]
130impl ExtensionSession for InteractiveExtensionSession {
131    async fn get_state(&self) -> Value {
132        let model = {
133            let guard = self.model_entry.lock().unwrap();
134            extension_model_from_entry(&guard)
135        };
136
137        let cx = Cx::for_request();
138        let (
139            session_file,
140            session_id,
141            session_name,
142            message_count,
143            thinking_level,
144            durability_mode,
145            autosave_pending_mutations,
146            autosave_max_pending_mutations,
147            autosave_flush_failed_count,
148            autosave_backpressure,
149            persistence_status,
150        ) = self.session.lock(&cx).await.map_or_else(
151            |_| {
152                (
153                    None,
154                    String::new(),
155                    None,
156                    0,
157                    "off".to_string(),
158                    "balanced".to_string(),
159                    0usize,
160                    0usize,
161                    0u64,
162                    false,
163                    "unknown".to_string(),
164                )
165            },
166            |guard| {
167                let message_count = guard
168                    .entries_for_current_path()
169                    .iter()
170                    .filter(|entry| matches!(entry, SessionEntry::Message(_)))
171                    .count();
172                let session_name = guard.get_name();
173                let thinking_level = guard
174                    .header
175                    .thinking_level
176                    .clone()
177                    .unwrap_or_else(|| "off".to_string());
178                let autosave_metrics = guard.autosave_metrics();
179                let durability_mode = guard.autosave_durability_mode().as_str().to_string();
180                let autosave_backpressure = autosave_metrics.max_pending_mutations > 0
181                    && autosave_metrics.pending_mutations >= autosave_metrics.max_pending_mutations;
182                let persistence_status = if autosave_metrics.flush_failed > 0 {
183                    "degraded"
184                } else if autosave_backpressure {
185                    "backpressure"
186                } else if autosave_metrics.pending_mutations > 0 {
187                    "draining"
188                } else {
189                    "healthy"
190                }
191                .to_string();
192                (
193                    guard.path.as_ref().map(|p| p.display().to_string()),
194                    guard.header.id.clone(),
195                    session_name,
196                    message_count,
197                    thinking_level,
198                    durability_mode,
199                    autosave_metrics.pending_mutations,
200                    autosave_metrics.max_pending_mutations,
201                    autosave_metrics.flush_failed,
202                    autosave_backpressure,
203                    persistence_status,
204                )
205            },
206        );
207
208        json!({
209            "model": model,
210            "thinkingLevel": thinking_level,
211            "isStreaming": self.is_streaming.load(Ordering::SeqCst),
212            "isCompacting": self.is_compacting.load(Ordering::SeqCst),
213            "steeringMode": "one-at-a-time",
214            "followUpMode": "one-at-a-time",
215            "sessionFile": session_file,
216            "sessionId": session_id,
217            "sessionName": session_name,
218            "autoCompactionEnabled": self.config.compaction_enabled(),
219            "messageCount": message_count,
220            "pendingMessageCount": autosave_pending_mutations,
221            "durabilityMode": durability_mode,
222            "autosavePendingMutations": autosave_pending_mutations,
223            "autosaveMaxPendingMutations": autosave_max_pending_mutations,
224            "autosaveFlushFailedCount": autosave_flush_failed_count,
225            "autosaveBackpressure": autosave_backpressure,
226            "persistenceStatus": persistence_status,
227        })
228    }
229
230    async fn get_messages(&self) -> Vec<SessionMessage> {
231        let cx = Cx::for_request();
232        let Ok(guard) = self.session.lock(&cx).await else {
233            return Vec::new();
234        };
235        guard
236            .entries_for_current_path()
237            .iter()
238            .filter_map(|entry| match entry {
239                SessionEntry::Message(msg) => match msg.message {
240                    SessionMessage::User { .. }
241                    | SessionMessage::Assistant { .. }
242                    | SessionMessage::ToolResult { .. }
243                    | SessionMessage::BashExecution { .. } => Some(msg.message.clone()),
244                    _ => None,
245                },
246                _ => None,
247            })
248            .collect::<Vec<_>>()
249    }
250
251    async fn get_entries(&self) -> Vec<Value> {
252        // Spec §3.1: return ALL session entries (entire session file), append order.
253        let cx = Cx::for_request();
254        let Ok(guard) = self.session.lock(&cx).await else {
255            return Vec::new();
256        };
257        guard
258            .entries
259            .iter()
260            .filter_map(|entry| serde_json::to_value(entry).ok())
261            .collect()
262    }
263
264    async fn get_branch(&self) -> Vec<Value> {
265        // Spec §3.2: return current path from root to leaf.
266        let cx = Cx::for_request();
267        let Ok(guard) = self.session.lock(&cx).await else {
268            return Vec::new();
269        };
270        guard
271            .entries_for_current_path()
272            .iter()
273            .filter_map(|entry| serde_json::to_value(*entry).ok())
274            .collect()
275    }
276
277    async fn set_name(&self, name: String) -> crate::error::Result<()> {
278        let cx = Cx::for_request();
279        let mut guard =
280            self.session.lock(&cx).await.map_err(|err| {
281                crate::error::Error::session(format!("session lock failed: {err}"))
282            })?;
283        guard.set_name(&name);
284        if self.save_enabled {
285            guard.save().await?;
286        }
287        Ok(())
288    }
289
290    async fn append_message(&self, message: SessionMessage) -> crate::error::Result<()> {
291        let cx = Cx::for_request();
292        let mut guard =
293            self.session.lock(&cx).await.map_err(|err| {
294                crate::error::Error::session(format!("session lock failed: {err}"))
295            })?;
296        guard.append_message(message);
297        if self.save_enabled {
298            guard.save().await?;
299        }
300        Ok(())
301    }
302
303    async fn append_custom_entry(
304        &self,
305        custom_type: String,
306        data: Option<Value>,
307    ) -> crate::error::Result<()> {
308        if custom_type.trim().is_empty() {
309            return Err(crate::error::Error::validation(
310                "customType must not be empty",
311            ));
312        }
313        let cx = Cx::for_request();
314        let mut guard =
315            self.session.lock(&cx).await.map_err(|err| {
316                crate::error::Error::session(format!("session lock failed: {err}"))
317            })?;
318        guard.append_custom_entry(custom_type, data);
319        if self.save_enabled {
320            guard.save().await?;
321        }
322        Ok(())
323    }
324
325    async fn set_model(&self, provider: String, model_id: String) -> crate::error::Result<()> {
326        let cx = Cx::for_request();
327        let mut guard =
328            self.session.lock(&cx).await.map_err(|err| {
329                crate::error::Error::session(format!("session lock failed: {err}"))
330            })?;
331        guard.append_model_change(provider.clone(), model_id.clone());
332        guard.set_model_header(Some(provider), Some(model_id), None);
333        if self.save_enabled {
334            guard.save().await?;
335        }
336        Ok(())
337    }
338
339    async fn get_model(&self) -> (Option<String>, Option<String>) {
340        let cx = Cx::for_request();
341        let Ok(guard) = self.session.lock(&cx).await else {
342            return (None, None);
343        };
344        (guard.header.provider.clone(), guard.header.model_id.clone())
345    }
346
347    async fn set_thinking_level(&self, level: String) -> crate::error::Result<()> {
348        let cx = Cx::for_request();
349        let mut guard =
350            self.session.lock(&cx).await.map_err(|err| {
351                crate::error::Error::session(format!("session lock failed: {err}"))
352            })?;
353        guard.append_thinking_level_change(level.clone());
354        guard.set_model_header(None, None, Some(level));
355        if self.save_enabled {
356            guard.save().await?;
357        }
358        Ok(())
359    }
360
361    async fn get_thinking_level(&self) -> Option<String> {
362        let cx = Cx::for_request();
363        let Ok(guard) = self.session.lock(&cx).await else {
364            return None;
365        };
366        guard.header.thinking_level.clone()
367    }
368
369    async fn set_label(
370        &self,
371        target_id: String,
372        label: Option<String>,
373    ) -> crate::error::Result<()> {
374        let cx = Cx::for_request();
375        let mut guard =
376            self.session.lock(&cx).await.map_err(|err| {
377                crate::error::Error::session(format!("session lock failed: {err}"))
378            })?;
379        if guard.add_label(&target_id, label).is_none() {
380            return Err(crate::error::Error::validation(format!(
381                "target entry '{target_id}' not found in session"
382            )));
383        }
384        if self.save_enabled {
385            guard.save().await?;
386        }
387        Ok(())
388    }
389}
390
391pub fn format_extension_ui_prompt(request: &ExtensionUiRequest) -> String {
392    let title = request
393        .payload
394        .get("title")
395        .and_then(Value::as_str)
396        .unwrap_or("Extension");
397    let message = request
398        .payload
399        .get("message")
400        .and_then(Value::as_str)
401        .unwrap_or("");
402
403    // Show provenance: which extension is making this request.
404    let provenance = request
405        .extension_id
406        .as_deref()
407        .or_else(|| request.payload.get("extension_id").and_then(Value::as_str))
408        .unwrap_or("unknown");
409
410    match request.method.as_str() {
411        "confirm" => {
412            format!("[{provenance}] confirm: {title}\n{message}\n\nEnter yes/no, or 'cancel'.")
413        }
414        "select" => {
415            let options = request
416                .payload
417                .get("options")
418                .and_then(Value::as_array)
419                .cloned()
420                .unwrap_or_default();
421
422            let mut out = String::new();
423            let _ = writeln!(&mut out, "[{provenance}] select: {title}");
424            if !message.trim().is_empty() {
425                let _ = writeln!(&mut out, "{message}");
426            }
427            for (idx, opt) in options.iter().enumerate() {
428                let label = opt
429                    .get("label")
430                    .and_then(Value::as_str)
431                    .or_else(|| opt.get("value").and_then(Value::as_str))
432                    .or_else(|| opt.as_str())
433                    .unwrap_or("");
434                let _ = writeln!(&mut out, "  {}) {label}", idx + 1);
435            }
436            out.push_str("\nEnter a number, label, or 'cancel'.");
437            out
438        }
439        "input" => format!("[{provenance}] input: {title}\n{message}"),
440        "editor" => format!("[{provenance}] editor: {title}\n{message}"),
441        _ => format!("[{provenance}] {title} {message}"),
442    }
443}
444
445pub fn parse_extension_ui_response(
446    request: &ExtensionUiRequest,
447    input: &str,
448) -> Result<ExtensionUiResponse, String> {
449    let trimmed = input.trim();
450
451    if trimmed.eq_ignore_ascii_case("cancel") || trimmed.eq_ignore_ascii_case("c") {
452        return Ok(ExtensionUiResponse {
453            id: request.id.clone(),
454            value: None,
455            cancelled: true,
456        });
457    }
458
459    match request.method.as_str() {
460        "confirm" => {
461            let value = match trimmed.to_lowercase().as_str() {
462                "y" | "yes" | "true" | "1" => true,
463                "n" | "no" | "false" | "0" => false,
464                _ => {
465                    return Err("Invalid confirmation. Enter yes/no, or 'cancel'.".to_string());
466                }
467            };
468            Ok(ExtensionUiResponse {
469                id: request.id.clone(),
470                value: Some(Value::Bool(value)),
471                cancelled: false,
472            })
473        }
474        "select" => {
475            let options = request
476                .payload
477                .get("options")
478                .and_then(Value::as_array)
479                .ok_or_else(|| {
480                    "Invalid selection. Enter a number, label, or 'cancel'.".to_string()
481                })?;
482
483            if let Ok(index) = trimmed.parse::<usize>() {
484                if index > 0 && index <= options.len() {
485                    let chosen = &options[index - 1];
486                    let value = chosen
487                        .get("value")
488                        .cloned()
489                        .or_else(|| chosen.get("label").cloned())
490                        .or_else(|| chosen.as_str().map(|s| Value::String(s.to_string())));
491                    return Ok(ExtensionUiResponse {
492                        id: request.id.clone(),
493                        value,
494                        cancelled: false,
495                    });
496                }
497            }
498
499            let lowered = trimmed.to_lowercase();
500            for option in options {
501                if let Some(value_str) = option.as_str() {
502                    if value_str.to_lowercase() == lowered {
503                        return Ok(ExtensionUiResponse {
504                            id: request.id.clone(),
505                            value: Some(Value::String(value_str.to_string())),
506                            cancelled: false,
507                        });
508                    }
509                }
510
511                let label = option.get("label").and_then(Value::as_str).unwrap_or("");
512                if !label.is_empty() && label.to_lowercase() == lowered {
513                    let value = option.get("value").cloned().or_else(|| {
514                        option
515                            .get("label")
516                            .and_then(Value::as_str)
517                            .map(|s| Value::String(s.to_string()))
518                    });
519                    return Ok(ExtensionUiResponse {
520                        id: request.id.clone(),
521                        value,
522                        cancelled: false,
523                    });
524                }
525
526                if let Some(value_str) = option.get("value").and_then(Value::as_str) {
527                    if value_str.to_lowercase() == lowered {
528                        return Ok(ExtensionUiResponse {
529                            id: request.id.clone(),
530                            value: Some(Value::String(value_str.to_string())),
531                            cancelled: false,
532                        });
533                    }
534                }
535            }
536
537            Err("Invalid selection. Enter a number, label, or 'cancel'.".to_string())
538        }
539        _ => Ok(ExtensionUiResponse {
540            id: request.id.clone(),
541            value: Some(Value::String(input.to_string())),
542            cancelled: false,
543        }),
544    }
545}