1use super::conversation::extension_model_from_entry;
2use super::*;
3
/// Interactive-mode implementation of [`ExtensionHostActions`].
///
/// Messages coming from extensions are either queued (while
/// `extension_streaming` is set) or applied immediately to the session,
/// the agent history, or the UI event loop.
#[derive(Clone)]
pub(super) struct InteractiveExtensionHostActions {
    /// Shared conversation session; custom messages are appended here when
    /// no turn is streaming.
    pub(super) session: Arc<Mutex<Session>>,
    /// Agent whose message history also receives custom messages when no
    /// turn is streaming.
    pub(super) agent: Arc<Mutex<Agent>>,
    /// Channel to the UI loop (system notes and pending user input).
    pub(super) event_tx: mpsc::Sender<PiMsg>,
    /// While `true`, incoming messages are queued instead of applied.
    pub(super) extension_streaming: Arc<AtomicBool>,
    /// Queue for extension-originated user text received while streaming.
    pub(super) user_queue: Arc<StdMutex<InteractiveMessageQueue>>,
    /// Queue for extension custom messages received while streaming.
    pub(super) injected_queue: Arc<StdMutex<InjectedMessageQueue>>,
}
13
14impl InteractiveExtensionHostActions {
15 #[allow(clippy::unnecessary_wraps)]
16 fn queue_custom_message(
17 &self,
18 deliver_as: Option<ExtensionDeliverAs>,
19 message: ModelMessage,
20 ) -> crate::error::Result<()> {
21 let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
22 let kind = match deliver_as {
23 ExtensionDeliverAs::FollowUp => QueuedMessageKind::FollowUp,
24 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => QueuedMessageKind::Steering,
25 };
26 let Ok(mut queue) = self.injected_queue.lock() else {
27 return Ok(());
28 };
29 match kind {
30 QueuedMessageKind::Steering => queue.push_steering(message),
31 QueuedMessageKind::FollowUp => queue.push_follow_up(message),
32 }
33 Ok(())
34 }
35
36 async fn append_to_session(&self, message: ModelMessage) -> crate::error::Result<()> {
37 let cx = Cx::for_request();
38 let mut session_guard = self
39 .session
40 .lock(&cx)
41 .await
42 .map_err(|e| crate::error::Error::session(e.to_string()))?;
43 session_guard.append_model_message(message);
44 Ok(())
45 }
46}
47
48#[async_trait]
49impl ExtensionHostActions for InteractiveExtensionHostActions {
50 async fn send_message(&self, message: ExtensionSendMessage) -> crate::error::Result<()> {
51 let custom_message = ModelMessage::Custom(CustomMessage {
52 content: message.content,
53 custom_type: message.custom_type,
54 display: message.display,
55 details: message.details,
56 timestamp: Utc::now().timestamp_millis(),
57 });
58
59 let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
60 if is_streaming {
61 self.queue_custom_message(message.deliver_as, custom_message.clone())?;
63 if let ModelMessage::Custom(custom) = &custom_message {
64 if custom.display {
65 let _ = self
66 .event_tx
67 .try_send(PiMsg::SystemNote(custom.content.clone()));
68 }
69 }
70 return Ok(());
71 }
72
73 let _ = message.trigger_turn;
76 self.append_to_session(custom_message.clone()).await?;
77
78 let cx = Cx::for_request();
79 if let Ok(mut agent_guard) = self.agent.lock(&cx).await {
80 agent_guard.add_message(custom_message.clone());
81 }
82
83 if let ModelMessage::Custom(custom) = &custom_message {
84 if custom.display {
85 let _ = self
86 .event_tx
87 .try_send(PiMsg::SystemNote(custom.content.clone()));
88 }
89 }
90
91 Ok(())
92 }
93
94 async fn send_user_message(
95 &self,
96 message: ExtensionSendUserMessage,
97 ) -> crate::error::Result<()> {
98 let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
99 if is_streaming {
100 let deliver_as = message.deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
101 let Ok(mut queue) = self.user_queue.lock() else {
102 return Ok(());
103 };
104 match deliver_as {
105 ExtensionDeliverAs::FollowUp => queue.push_follow_up(message.text),
106 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
107 queue.push_steering(message.text);
108 }
109 }
110 return Ok(());
111 }
112
113 let _ = self
114 .event_tx
115 .try_send(PiMsg::EnqueuePendingInput(PendingInput::Text(message.text)));
116 Ok(())
117 }
118}
119
/// Interactive-mode implementation of [`ExtensionSession`]: exposes session
/// state to extensions and applies the session mutations they request.
pub(super) struct InteractiveExtensionSession {
    /// Shared conversation session read and mutated by the trait methods.
    pub(super) session: Arc<Mutex<Session>>,
    /// Model entry reported as `model` by `get_state`.
    pub(super) model_entry: Arc<StdMutex<ModelEntry>>,
    /// Reported as `isStreaming` by `get_state`.
    pub(super) is_streaming: Arc<AtomicBool>,
    /// Reported as `isCompacting` by `get_state`.
    pub(super) is_compacting: Arc<AtomicBool>,
    /// Source of `autoCompactionEnabled` in `get_state`.
    pub(super) config: Config,
    /// When `true`, every mutating method calls `Session::save` afterwards.
    pub(super) save_enabled: bool,
}
128
129#[async_trait]
130impl ExtensionSession for InteractiveExtensionSession {
131 async fn get_state(&self) -> Value {
132 let model = {
133 let guard = self.model_entry.lock().unwrap();
134 extension_model_from_entry(&guard)
135 };
136
137 let cx = Cx::for_request();
138 let (
139 session_file,
140 session_id,
141 session_name,
142 message_count,
143 thinking_level,
144 durability_mode,
145 autosave_pending_mutations,
146 autosave_max_pending_mutations,
147 autosave_flush_failed_count,
148 autosave_backpressure,
149 persistence_status,
150 ) = self.session.lock(&cx).await.map_or_else(
151 |_| {
152 (
153 None,
154 String::new(),
155 None,
156 0,
157 "off".to_string(),
158 "balanced".to_string(),
159 0usize,
160 0usize,
161 0u64,
162 false,
163 "unknown".to_string(),
164 )
165 },
166 |guard| {
167 let message_count = guard
168 .entries_for_current_path()
169 .iter()
170 .filter(|entry| matches!(entry, SessionEntry::Message(_)))
171 .count();
172 let session_name = guard.get_name();
173 let thinking_level = guard
174 .header
175 .thinking_level
176 .clone()
177 .unwrap_or_else(|| "off".to_string());
178 let autosave_metrics = guard.autosave_metrics();
179 let durability_mode = guard.autosave_durability_mode().as_str().to_string();
180 let autosave_backpressure = autosave_metrics.max_pending_mutations > 0
181 && autosave_metrics.pending_mutations >= autosave_metrics.max_pending_mutations;
182 let persistence_status = if autosave_metrics.flush_failed > 0 {
183 "degraded"
184 } else if autosave_backpressure {
185 "backpressure"
186 } else if autosave_metrics.pending_mutations > 0 {
187 "draining"
188 } else {
189 "healthy"
190 }
191 .to_string();
192 (
193 guard.path.as_ref().map(|p| p.display().to_string()),
194 guard.header.id.clone(),
195 session_name,
196 message_count,
197 thinking_level,
198 durability_mode,
199 autosave_metrics.pending_mutations,
200 autosave_metrics.max_pending_mutations,
201 autosave_metrics.flush_failed,
202 autosave_backpressure,
203 persistence_status,
204 )
205 },
206 );
207
208 json!({
209 "model": model,
210 "thinkingLevel": thinking_level,
211 "isStreaming": self.is_streaming.load(Ordering::SeqCst),
212 "isCompacting": self.is_compacting.load(Ordering::SeqCst),
213 "steeringMode": "one-at-a-time",
214 "followUpMode": "one-at-a-time",
215 "sessionFile": session_file,
216 "sessionId": session_id,
217 "sessionName": session_name,
218 "autoCompactionEnabled": self.config.compaction_enabled(),
219 "messageCount": message_count,
220 "pendingMessageCount": autosave_pending_mutations,
221 "durabilityMode": durability_mode,
222 "autosavePendingMutations": autosave_pending_mutations,
223 "autosaveMaxPendingMutations": autosave_max_pending_mutations,
224 "autosaveFlushFailedCount": autosave_flush_failed_count,
225 "autosaveBackpressure": autosave_backpressure,
226 "persistenceStatus": persistence_status,
227 })
228 }
229
230 async fn get_messages(&self) -> Vec<SessionMessage> {
231 let cx = Cx::for_request();
232 let Ok(guard) = self.session.lock(&cx).await else {
233 return Vec::new();
234 };
235 guard
236 .entries_for_current_path()
237 .iter()
238 .filter_map(|entry| match entry {
239 SessionEntry::Message(msg) => match msg.message {
240 SessionMessage::User { .. }
241 | SessionMessage::Assistant { .. }
242 | SessionMessage::ToolResult { .. }
243 | SessionMessage::BashExecution { .. } => Some(msg.message.clone()),
244 _ => None,
245 },
246 _ => None,
247 })
248 .collect::<Vec<_>>()
249 }
250
251 async fn get_entries(&self) -> Vec<Value> {
252 let cx = Cx::for_request();
254 let Ok(guard) = self.session.lock(&cx).await else {
255 return Vec::new();
256 };
257 guard
258 .entries
259 .iter()
260 .filter_map(|entry| serde_json::to_value(entry).ok())
261 .collect()
262 }
263
264 async fn get_branch(&self) -> Vec<Value> {
265 let cx = Cx::for_request();
267 let Ok(guard) = self.session.lock(&cx).await else {
268 return Vec::new();
269 };
270 guard
271 .entries_for_current_path()
272 .iter()
273 .filter_map(|entry| serde_json::to_value(*entry).ok())
274 .collect()
275 }
276
277 async fn set_name(&self, name: String) -> crate::error::Result<()> {
278 let cx = Cx::for_request();
279 let mut guard =
280 self.session.lock(&cx).await.map_err(|err| {
281 crate::error::Error::session(format!("session lock failed: {err}"))
282 })?;
283 guard.set_name(&name);
284 if self.save_enabled {
285 guard.save().await?;
286 }
287 Ok(())
288 }
289
290 async fn append_message(&self, message: SessionMessage) -> crate::error::Result<()> {
291 let cx = Cx::for_request();
292 let mut guard =
293 self.session.lock(&cx).await.map_err(|err| {
294 crate::error::Error::session(format!("session lock failed: {err}"))
295 })?;
296 guard.append_message(message);
297 if self.save_enabled {
298 guard.save().await?;
299 }
300 Ok(())
301 }
302
303 async fn append_custom_entry(
304 &self,
305 custom_type: String,
306 data: Option<Value>,
307 ) -> crate::error::Result<()> {
308 if custom_type.trim().is_empty() {
309 return Err(crate::error::Error::validation(
310 "customType must not be empty",
311 ));
312 }
313 let cx = Cx::for_request();
314 let mut guard =
315 self.session.lock(&cx).await.map_err(|err| {
316 crate::error::Error::session(format!("session lock failed: {err}"))
317 })?;
318 guard.append_custom_entry(custom_type, data);
319 if self.save_enabled {
320 guard.save().await?;
321 }
322 Ok(())
323 }
324
325 async fn set_model(&self, provider: String, model_id: String) -> crate::error::Result<()> {
326 let cx = Cx::for_request();
327 let mut guard =
328 self.session.lock(&cx).await.map_err(|err| {
329 crate::error::Error::session(format!("session lock failed: {err}"))
330 })?;
331 guard.append_model_change(provider.clone(), model_id.clone());
332 guard.set_model_header(Some(provider), Some(model_id), None);
333 if self.save_enabled {
334 guard.save().await?;
335 }
336 Ok(())
337 }
338
339 async fn get_model(&self) -> (Option<String>, Option<String>) {
340 let cx = Cx::for_request();
341 let Ok(guard) = self.session.lock(&cx).await else {
342 return (None, None);
343 };
344 (guard.header.provider.clone(), guard.header.model_id.clone())
345 }
346
347 async fn set_thinking_level(&self, level: String) -> crate::error::Result<()> {
348 let cx = Cx::for_request();
349 let mut guard =
350 self.session.lock(&cx).await.map_err(|err| {
351 crate::error::Error::session(format!("session lock failed: {err}"))
352 })?;
353 guard.append_thinking_level_change(level.clone());
354 guard.set_model_header(None, None, Some(level));
355 if self.save_enabled {
356 guard.save().await?;
357 }
358 Ok(())
359 }
360
361 async fn get_thinking_level(&self) -> Option<String> {
362 let cx = Cx::for_request();
363 let Ok(guard) = self.session.lock(&cx).await else {
364 return None;
365 };
366 guard.header.thinking_level.clone()
367 }
368
369 async fn set_label(
370 &self,
371 target_id: String,
372 label: Option<String>,
373 ) -> crate::error::Result<()> {
374 let cx = Cx::for_request();
375 let mut guard =
376 self.session.lock(&cx).await.map_err(|err| {
377 crate::error::Error::session(format!("session lock failed: {err}"))
378 })?;
379 if guard.add_label(&target_id, label).is_none() {
380 return Err(crate::error::Error::validation(format!(
381 "target entry '{target_id}' not found in session"
382 )));
383 }
384 if self.save_enabled {
385 guard.save().await?;
386 }
387 Ok(())
388 }
389}
390
391pub fn format_extension_ui_prompt(request: &ExtensionUiRequest) -> String {
392 let title = request
393 .payload
394 .get("title")
395 .and_then(Value::as_str)
396 .unwrap_or("Extension");
397 let message = request
398 .payload
399 .get("message")
400 .and_then(Value::as_str)
401 .unwrap_or("");
402
403 let provenance = request
405 .extension_id
406 .as_deref()
407 .or_else(|| request.payload.get("extension_id").and_then(Value::as_str))
408 .unwrap_or("unknown");
409
410 match request.method.as_str() {
411 "confirm" => {
412 format!("[{provenance}] confirm: {title}\n{message}\n\nEnter yes/no, or 'cancel'.")
413 }
414 "select" => {
415 let options = request
416 .payload
417 .get("options")
418 .and_then(Value::as_array)
419 .cloned()
420 .unwrap_or_default();
421
422 let mut out = String::new();
423 let _ = writeln!(&mut out, "[{provenance}] select: {title}");
424 if !message.trim().is_empty() {
425 let _ = writeln!(&mut out, "{message}");
426 }
427 for (idx, opt) in options.iter().enumerate() {
428 let label = opt
429 .get("label")
430 .and_then(Value::as_str)
431 .or_else(|| opt.get("value").and_then(Value::as_str))
432 .or_else(|| opt.as_str())
433 .unwrap_or("");
434 let _ = writeln!(&mut out, " {}) {label}", idx + 1);
435 }
436 out.push_str("\nEnter a number, label, or 'cancel'.");
437 out
438 }
439 "input" => format!("[{provenance}] input: {title}\n{message}"),
440 "editor" => format!("[{provenance}] editor: {title}\n{message}"),
441 _ => format!("[{provenance}] {title} {message}"),
442 }
443}
444
445pub fn parse_extension_ui_response(
446 request: &ExtensionUiRequest,
447 input: &str,
448) -> Result<ExtensionUiResponse, String> {
449 let trimmed = input.trim();
450
451 if trimmed.eq_ignore_ascii_case("cancel") || trimmed.eq_ignore_ascii_case("c") {
452 return Ok(ExtensionUiResponse {
453 id: request.id.clone(),
454 value: None,
455 cancelled: true,
456 });
457 }
458
459 match request.method.as_str() {
460 "confirm" => {
461 let value = match trimmed.to_lowercase().as_str() {
462 "y" | "yes" | "true" | "1" => true,
463 "n" | "no" | "false" | "0" => false,
464 _ => {
465 return Err("Invalid confirmation. Enter yes/no, or 'cancel'.".to_string());
466 }
467 };
468 Ok(ExtensionUiResponse {
469 id: request.id.clone(),
470 value: Some(Value::Bool(value)),
471 cancelled: false,
472 })
473 }
474 "select" => {
475 let options = request
476 .payload
477 .get("options")
478 .and_then(Value::as_array)
479 .ok_or_else(|| {
480 "Invalid selection. Enter a number, label, or 'cancel'.".to_string()
481 })?;
482
483 if let Ok(index) = trimmed.parse::<usize>() {
484 if index > 0 && index <= options.len() {
485 let chosen = &options[index - 1];
486 let value = chosen
487 .get("value")
488 .cloned()
489 .or_else(|| chosen.get("label").cloned())
490 .or_else(|| chosen.as_str().map(|s| Value::String(s.to_string())));
491 return Ok(ExtensionUiResponse {
492 id: request.id.clone(),
493 value,
494 cancelled: false,
495 });
496 }
497 }
498
499 let lowered = trimmed.to_lowercase();
500 for option in options {
501 if let Some(value_str) = option.as_str() {
502 if value_str.to_lowercase() == lowered {
503 return Ok(ExtensionUiResponse {
504 id: request.id.clone(),
505 value: Some(Value::String(value_str.to_string())),
506 cancelled: false,
507 });
508 }
509 }
510
511 let label = option.get("label").and_then(Value::as_str).unwrap_or("");
512 if !label.is_empty() && label.to_lowercase() == lowered {
513 let value = option.get("value").cloned().or_else(|| {
514 option
515 .get("label")
516 .and_then(Value::as_str)
517 .map(|s| Value::String(s.to_string()))
518 });
519 return Ok(ExtensionUiResponse {
520 id: request.id.clone(),
521 value,
522 cancelled: false,
523 });
524 }
525
526 if let Some(value_str) = option.get("value").and_then(Value::as_str) {
527 if value_str.to_lowercase() == lowered {
528 return Ok(ExtensionUiResponse {
529 id: request.id.clone(),
530 value: Some(Value::String(value_str.to_string())),
531 cancelled: false,
532 });
533 }
534 }
535 }
536
537 Err("Invalid selection. Enter a number, label, or 'cancel'.".to_string())
538 }
539 _ => Ok(ExtensionUiResponse {
540 id: request.id.clone(),
541 value: Some(Value::String(input.to_string())),
542 cancelled: false,
543 }),
544 }
545}