1use super::conversation::extension_model_from_entry;
2use super::*;
3use crate::provider_metadata::{canonical_provider_id, provider_ids_match};
4
/// Host-side actions that extensions can invoke while the interactive UI is running.
///
/// Derives `Clone`; every field is an `Arc` handle or a channel sender, so clones
/// share the same underlying session, agent, flags, and queues.
#[derive(Clone)]
pub(super) struct InteractiveExtensionHostActions {
    /// Session that idle `send_message` calls append custom messages to.
    pub(super) session: Arc<Mutex<Session>>,
    /// Agent whose message history also receives idle custom messages.
    pub(super) agent: Arc<Mutex<Agent>>,
    /// Channel for posting UI events (system notes, pending input) back to the app loop.
    pub(super) event_tx: mpsc::Sender<PiMsg>,
    /// When set, messages are queued for the in-flight turn instead of appended directly.
    pub(super) extension_streaming: Arc<AtomicBool>,
    /// Queue for extension-originated user text while streaming.
    pub(super) user_queue: Arc<StdMutex<InteractiveMessageQueue>>,
    /// Queue for extension-originated custom model messages while streaming.
    pub(super) injected_queue: Arc<StdMutex<InjectedMessageQueue>>,
}
14
15impl InteractiveExtensionHostActions {
16 const fn should_trigger_turn(
17 deliver_as: Option<ExtensionDeliverAs>,
18 trigger_turn: bool,
19 ) -> bool {
20 trigger_turn && !matches!(deliver_as, Some(ExtensionDeliverAs::NextTurn))
21 }
22
23 #[allow(clippy::unnecessary_wraps)]
24 fn queue_custom_message(
25 &self,
26 deliver_as: Option<ExtensionDeliverAs>,
27 message: ModelMessage,
28 ) -> crate::error::Result<()> {
29 let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
30 let kind = match deliver_as {
31 ExtensionDeliverAs::FollowUp => QueuedMessageKind::FollowUp,
32 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => QueuedMessageKind::Steering,
33 };
34 let Ok(mut queue) = self.injected_queue.lock() else {
35 return Ok(());
36 };
37 match kind {
38 QueuedMessageKind::Steering => queue.push_steering(message),
39 QueuedMessageKind::FollowUp => queue.push_follow_up(message),
40 }
41 Ok(())
42 }
43
44 async fn append_to_session(&self, message: ModelMessage) -> crate::error::Result<()> {
45 let cx = Cx::current().unwrap_or_else(Cx::for_request);
46 let mut session_guard = self
47 .session
48 .lock(&cx)
49 .await
50 .map_err(|e| crate::error::Error::session(e.to_string()))?;
51 session_guard.append_model_message(message);
52 Ok(())
53 }
54}
55
56#[async_trait]
57impl ExtensionHostActions for InteractiveExtensionHostActions {
58 async fn send_message(&self, message: ExtensionSendMessage) -> crate::error::Result<()> {
59 let custom_message = ModelMessage::Custom(CustomMessage {
60 content: message.content,
61 custom_type: message.custom_type,
62 display: message.display,
63 details: message.details,
64 timestamp: Utc::now().timestamp_millis(),
65 });
66 let cx = Cx::current().unwrap_or_else(Cx::for_request);
67
68 let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
69 if is_streaming {
70 self.queue_custom_message(message.deliver_as, custom_message.clone())?;
72 if let ModelMessage::Custom(custom) = &custom_message {
73 if custom.display {
74 let _ = enqueue_pi_event(
75 &self.event_tx,
76 &cx,
77 PiMsg::SystemNote(custom.content.clone()),
78 )
79 .await;
80 }
81 }
82 return Ok(());
83 }
84
85 self.append_to_session(custom_message.clone()).await?;
87
88 if let Ok(mut agent_guard) = self.agent.lock(&cx).await {
89 agent_guard.add_message(custom_message.clone());
90 }
91
92 if let ModelMessage::Custom(custom) = &custom_message {
93 if custom.display {
94 let _ = enqueue_pi_event(
95 &self.event_tx,
96 &cx,
97 PiMsg::SystemNote(custom.content.clone()),
98 )
99 .await;
100 }
101 }
102
103 if Self::should_trigger_turn(message.deliver_as, message.trigger_turn) {
104 let _ = enqueue_pi_event(
105 &self.event_tx,
106 &cx,
107 PiMsg::EnqueuePendingInput(PendingInput::Continue),
108 )
109 .await;
110 }
111
112 Ok(())
113 }
114
115 async fn send_user_message(
116 &self,
117 message: ExtensionSendUserMessage,
118 ) -> crate::error::Result<()> {
119 let is_streaming = self.extension_streaming.load(Ordering::SeqCst);
120 if is_streaming {
121 let deliver_as = message.deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
122 let Ok(mut queue) = self.user_queue.lock() else {
123 return Ok(());
124 };
125 match deliver_as {
126 ExtensionDeliverAs::FollowUp => queue.push_follow_up(message.text),
127 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
128 queue.push_steering(message.text);
129 }
130 }
131 return Ok(());
132 }
133
134 let cx = Cx::current().unwrap_or_else(Cx::for_request);
135 let _ = enqueue_pi_event(
136 &self.event_tx,
137 &cx,
138 PiMsg::EnqueuePendingInput(PendingInput::Text(message.text)),
139 )
140 .await;
141 Ok(())
142 }
143}
144
/// `ExtensionSession` adapter that exposes the interactive session to extensions.
pub(super) struct InteractiveExtensionSession {
    /// The live session extensions read from and write to.
    pub(super) session: Arc<Mutex<Session>>,
    /// Model entry shared with the UI; used for model state and thinking-level clamping.
    pub(super) model_entry: Arc<StdMutex<ModelEntry>>,
    /// Reported as `isStreaming` in `get_state`.
    pub(super) is_streaming: Arc<AtomicBool>,
    /// Reported as `isCompacting` in `get_state`.
    pub(super) is_compacting: Arc<AtomicBool>,
    /// Source of queue modes and the auto-compaction flag reported by `get_state`.
    pub(super) config: Config,
    /// When true, mutating operations call `Session::save` after applying changes.
    pub(super) save_enabled: bool,
}
153
154fn current_path_model_pair(session: &Session) -> Option<(String, String)> {
155 session.effective_model_for_current_path()
156}
157
158fn current_path_model_fields(session: &Session) -> (Option<String>, Option<String>) {
159 if let Some((provider, model_id)) = current_path_model_pair(session) {
160 (Some(provider), Some(model_id))
161 } else {
162 (None, None)
163 }
164}
165
166fn current_path_thinking_level(session: &Session) -> Option<String> {
167 session.effective_thinking_level_for_current_path()
168}
169
170fn session_model_state_value(shared_model: &ModelEntry, session: &Session) -> Value {
171 match current_path_model_pair(session) {
172 Some((provider, model_id))
173 if provider_ids_match(&shared_model.model.provider, &provider)
174 && shared_model.model.id.eq_ignore_ascii_case(&model_id) =>
175 {
176 extension_model_from_entry(shared_model)
177 }
178 Some((provider, model_id)) => json!({
179 "provider": provider,
180 "id": model_id,
181 }),
182 None => Value::Null,
183 }
184}
185
186#[async_trait]
187impl ExtensionSession for InteractiveExtensionSession {
188 async fn get_state(&self) -> Value {
189 let shared_model = self
190 .model_entry
191 .lock()
192 .unwrap_or_else(std::sync::PoisonError::into_inner)
193 .clone();
194 let fallback_model = extension_model_from_entry(&shared_model);
195
196 let cx = Cx::current().unwrap_or_else(Cx::for_request);
197 let (
198 model,
199 session_file,
200 session_id,
201 session_name,
202 message_count,
203 thinking_level,
204 durability_mode,
205 autosave_pending_mutations,
206 autosave_max_pending_mutations,
207 autosave_flush_failed_count,
208 autosave_backpressure,
209 persistence_status,
210 ) = self.session.lock(&cx).await.map_or_else(
211 |_| {
212 (
213 fallback_model.clone(),
214 None,
215 String::new(),
216 None,
217 0,
218 "off".to_string(),
219 "balanced".to_string(),
220 0usize,
221 0usize,
222 0u64,
223 false,
224 "unknown".to_string(),
225 )
226 },
227 |guard| {
228 let model = session_model_state_value(&shared_model, &guard);
229 let message_count = guard
230 .entries_for_current_path()
231 .iter()
232 .filter(|entry| matches!(entry, SessionEntry::Message(_)))
233 .count();
234 let session_name = guard.get_name();
235 let thinking_level =
236 current_path_thinking_level(&guard).unwrap_or_else(|| "off".to_string());
237 let autosave_metrics = guard.autosave_metrics();
238 let durability_mode = guard.autosave_durability_mode().as_str().to_string();
239 let autosave_backpressure = autosave_metrics.max_pending_mutations > 0
240 && autosave_metrics.pending_mutations >= autosave_metrics.max_pending_mutations;
241 let persistence_status = if autosave_metrics.flush_failed > 0 {
242 "degraded"
243 } else if autosave_backpressure {
244 "backpressure"
245 } else if autosave_metrics.pending_mutations > 0 {
246 "draining"
247 } else {
248 "healthy"
249 }
250 .to_string();
251 (
252 model,
253 guard.path.as_ref().map(|p| p.display().to_string()),
254 guard.header.id.clone(),
255 session_name,
256 message_count,
257 thinking_level,
258 durability_mode,
259 autosave_metrics.pending_mutations,
260 autosave_metrics.max_pending_mutations,
261 autosave_metrics.flush_failed,
262 autosave_backpressure,
263 persistence_status,
264 )
265 },
266 );
267
268 json!({
269 "model": model,
270 "thinkingLevel": thinking_level,
271 "isStreaming": self.is_streaming.load(Ordering::SeqCst),
272 "isCompacting": self.is_compacting.load(Ordering::SeqCst),
273 "steeringMode": self.config.steering_queue_mode().as_str(),
274 "followUpMode": self.config.follow_up_queue_mode().as_str(),
275 "sessionFile": session_file,
276 "sessionId": session_id,
277 "sessionName": session_name,
278 "autoCompactionEnabled": self.config.compaction_enabled(),
279 "messageCount": message_count,
280 "pendingMessageCount": autosave_pending_mutations,
281 "durabilityMode": durability_mode,
282 "autosavePendingMutations": autosave_pending_mutations,
283 "autosaveMaxPendingMutations": autosave_max_pending_mutations,
284 "autosaveFlushFailedCount": autosave_flush_failed_count,
285 "autosaveBackpressure": autosave_backpressure,
286 "persistenceStatus": persistence_status,
287 })
288 }
289
290 async fn get_messages(&self) -> Vec<SessionMessage> {
291 let cx = Cx::current().unwrap_or_else(Cx::for_request);
292 let Ok(guard) = self.session.lock(&cx).await else {
293 return Vec::new();
294 };
295 guard
296 .entries_for_current_path()
297 .iter()
298 .filter_map(|entry| match entry {
299 SessionEntry::Message(msg) => match msg.message {
300 SessionMessage::User { .. }
301 | SessionMessage::Assistant { .. }
302 | SessionMessage::ToolResult { .. }
303 | SessionMessage::BashExecution { .. }
304 | SessionMessage::Custom { .. } => Some(msg.message.clone()),
305 _ => None,
306 },
307 _ => None,
308 })
309 .collect::<Vec<_>>()
310 }
311
312 async fn get_entries(&self) -> Vec<Value> {
313 let cx = Cx::current().unwrap_or_else(Cx::for_request);
315 let Ok(guard) = self.session.lock(&cx).await else {
316 return Vec::new();
317 };
318 guard
319 .entries
320 .iter()
321 .filter_map(|entry| serde_json::to_value(entry).ok())
322 .collect()
323 }
324
325 async fn get_branch(&self) -> Vec<Value> {
326 let cx = Cx::current().unwrap_or_else(Cx::for_request);
328 let Ok(guard) = self.session.lock(&cx).await else {
329 return Vec::new();
330 };
331 guard
332 .entries_for_current_path()
333 .iter()
334 .filter_map(|entry| serde_json::to_value(*entry).ok())
335 .collect()
336 }
337
338 async fn set_name(&self, name: String) -> crate::error::Result<()> {
339 let cx = Cx::current().unwrap_or_else(Cx::for_request);
340 let mut guard =
341 self.session.lock(&cx).await.map_err(|err| {
342 crate::error::Error::session(format!("session lock failed: {err}"))
343 })?;
344 guard.set_name(&name);
345 if self.save_enabled {
346 guard.save().await?;
347 }
348 Ok(())
349 }
350
351 async fn append_message(&self, message: SessionMessage) -> crate::error::Result<()> {
352 let cx = Cx::current().unwrap_or_else(Cx::for_request);
353 let mut guard =
354 self.session.lock(&cx).await.map_err(|err| {
355 crate::error::Error::session(format!("session lock failed: {err}"))
356 })?;
357 guard.append_message(message);
358 if self.save_enabled {
359 guard.save().await?;
360 }
361 Ok(())
362 }
363
364 async fn append_custom_entry(
365 &self,
366 custom_type: String,
367 data: Option<Value>,
368 ) -> crate::error::Result<()> {
369 if custom_type.trim().is_empty() {
370 return Err(crate::error::Error::validation(
371 "customType must not be empty",
372 ));
373 }
374 let cx = Cx::current().unwrap_or_else(Cx::for_request);
375 let mut guard =
376 self.session.lock(&cx).await.map_err(|err| {
377 crate::error::Error::session(format!("session lock failed: {err}"))
378 })?;
379 guard.append_custom_entry(custom_type, data);
380 if self.save_enabled {
381 guard.save().await?;
382 }
383 Ok(())
384 }
385
386 async fn set_model(&self, provider: String, model_id: String) -> crate::error::Result<()> {
387 let cx = Cx::current().unwrap_or_else(Cx::for_request);
388 let mut guard =
389 self.session.lock(&cx).await.map_err(|err| {
390 crate::error::Error::session(format!("session lock failed: {err}"))
391 })?;
392 let normalized_provider = canonical_provider_id(&provider)
393 .unwrap_or(&provider)
394 .to_string();
395 let (stored_provider, stored_model_id, changed) = match current_path_model_pair(&guard) {
396 Some((current_provider, current_model_id))
397 if provider_ids_match(¤t_provider, &provider)
398 && current_model_id.eq_ignore_ascii_case(&model_id) =>
399 {
400 (current_provider, current_model_id, false)
401 }
402 _ => (normalized_provider, model_id.clone(), true),
403 };
404 if changed {
405 guard.append_model_change(stored_provider.clone(), stored_model_id.clone());
406 }
407 guard.set_model_header(Some(stored_provider), Some(stored_model_id), None);
408 if self.save_enabled {
409 guard.save().await?;
410 }
411 Ok(())
412 }
413
414 async fn get_model(&self) -> (Option<String>, Option<String>) {
415 let cx = Cx::current().unwrap_or_else(Cx::for_request);
416 let Ok(guard) = self.session.lock(&cx).await else {
417 return (None, None);
418 };
419 current_path_model_fields(&guard)
420 }
421
422 async fn set_thinking_level(&self, level: String) -> crate::error::Result<()> {
423 let cx = Cx::current().unwrap_or_else(Cx::for_request);
424 let shared_model = self.model_entry.lock().map(|entry| entry.clone()).ok();
425 let mut guard =
426 self.session.lock(&cx).await.map_err(|err| {
427 crate::error::Error::session(format!("session lock failed: {err}"))
428 })?;
429 let effective_level = level.parse::<crate::model::ThinkingLevel>().map_or_else(
430 |_| level.clone(),
431 |parsed| match (shared_model.as_ref(), current_path_model_pair(&guard)) {
432 (Some(entry), Some((provider, model_id)))
433 if provider_ids_match(&entry.model.provider, &provider)
434 && entry.model.id.eq_ignore_ascii_case(&model_id) =>
435 {
436 entry.clamp_thinking_level(parsed).to_string()
437 }
438 (Some(entry), None) => entry.clamp_thinking_level(parsed).to_string(),
439 _ => level.clone(),
440 },
441 );
442 let changed =
443 current_path_thinking_level(&guard).as_deref() != Some(effective_level.as_str());
444 guard.set_model_header(None, None, Some(effective_level.clone()));
445 if changed {
446 guard.append_thinking_level_change(effective_level);
447 }
448 if changed && self.save_enabled {
449 guard.save().await?;
450 }
451 Ok(())
452 }
453
454 async fn get_thinking_level(&self) -> Option<String> {
455 let cx = Cx::current().unwrap_or_else(Cx::for_request);
456 let Ok(guard) = self.session.lock(&cx).await else {
457 return None;
458 };
459 current_path_thinking_level(&guard)
460 }
461
462 async fn set_label(
463 &self,
464 target_id: String,
465 label: Option<String>,
466 ) -> crate::error::Result<()> {
467 let cx = Cx::current().unwrap_or_else(Cx::for_request);
468 let mut guard =
469 self.session.lock(&cx).await.map_err(|err| {
470 crate::error::Error::session(format!("session lock failed: {err}"))
471 })?;
472 if guard.add_label(&target_id, label).is_none() {
473 return Err(crate::error::Error::validation(format!(
474 "target entry '{target_id}' not found in session"
475 )));
476 }
477 if self.save_enabled {
478 guard.save().await?;
479 }
480 Ok(())
481 }
482}
483
484pub fn format_extension_ui_prompt(request: &ExtensionUiRequest) -> String {
485 let title = request
486 .payload
487 .get("title")
488 .and_then(Value::as_str)
489 .unwrap_or("Extension");
490 let message = request
491 .payload
492 .get("message")
493 .and_then(Value::as_str)
494 .unwrap_or("");
495
496 let provenance = request
498 .extension_id
499 .as_deref()
500 .or_else(|| request.payload.get("extension_id").and_then(Value::as_str))
501 .unwrap_or("unknown");
502
503 match request.method.as_str() {
504 "confirm" => {
505 format!("[{provenance}] confirm: {title}\n{message}\n\nEnter yes/no, or 'cancel'.")
506 }
507 "select" => {
508 let options = request
509 .payload
510 .get("options")
511 .and_then(Value::as_array)
512 .cloned()
513 .unwrap_or_default();
514
515 let mut out = String::new();
516 let _ = writeln!(&mut out, "[{provenance}] select: {title}");
517 if !message.trim().is_empty() {
518 let _ = writeln!(&mut out, "{message}");
519 }
520 for (idx, opt) in options.iter().enumerate() {
521 let label = opt
522 .get("label")
523 .and_then(Value::as_str)
524 .or_else(|| opt.get("value").and_then(Value::as_str))
525 .or_else(|| opt.as_str())
526 .unwrap_or("");
527 let _ = writeln!(&mut out, " {}) {label}", idx + 1);
528 }
529 out.push_str("\nEnter a number, label, or 'cancel'.");
530 out
531 }
532 "input" => format!("[{provenance}] input: {title}\n{message}"),
533 "editor" => format!("[{provenance}] editor: {title}\n{message}"),
534 _ => format!("[{provenance}] {title} {message}"),
535 }
536}
537
538pub fn parse_extension_ui_response(
539 request: &ExtensionUiRequest,
540 input: &str,
541) -> Result<ExtensionUiResponse, String> {
542 let trimmed = input.trim();
543
544 if trimmed.eq_ignore_ascii_case("cancel") || trimmed.eq_ignore_ascii_case("c") {
545 return Ok(ExtensionUiResponse {
546 id: request.id.clone(),
547 value: None,
548 cancelled: true,
549 });
550 }
551
552 match request.method.as_str() {
553 "confirm" => {
554 let value = match trimmed.to_lowercase().as_str() {
555 "y" | "yes" | "true" | "1" => true,
556 "n" | "no" | "false" | "0" => false,
557 _ => {
558 return Err("Invalid confirmation. Enter yes/no, or 'cancel'.".to_string());
559 }
560 };
561 Ok(ExtensionUiResponse {
562 id: request.id.clone(),
563 value: Some(Value::Bool(value)),
564 cancelled: false,
565 })
566 }
567 "select" => {
568 let options = request
569 .payload
570 .get("options")
571 .and_then(Value::as_array)
572 .ok_or_else(|| {
573 "Invalid selection. Enter a number, label, or 'cancel'.".to_string()
574 })?;
575
576 if let Ok(index) = trimmed.parse::<usize>() {
577 if index > 0 && index <= options.len() {
578 let chosen = &options[index - 1];
579 let value = chosen
580 .get("value")
581 .cloned()
582 .or_else(|| chosen.get("label").cloned())
583 .or_else(|| chosen.as_str().map(|s| Value::String(s.to_string())));
584 return Ok(ExtensionUiResponse {
585 id: request.id.clone(),
586 value,
587 cancelled: false,
588 });
589 }
590 }
591
592 let lowered = trimmed.to_lowercase();
593 for option in options {
594 if let Some(value_str) = option.as_str() {
595 if value_str.to_lowercase() == lowered {
596 return Ok(ExtensionUiResponse {
597 id: request.id.clone(),
598 value: Some(Value::String(value_str.to_string())),
599 cancelled: false,
600 });
601 }
602 }
603
604 let label = option.get("label").and_then(Value::as_str).unwrap_or("");
605 if !label.is_empty() && label.to_lowercase() == lowered {
606 let value = option.get("value").cloned().or_else(|| {
607 option
608 .get("label")
609 .and_then(Value::as_str)
610 .map(|s| Value::String(s.to_string()))
611 });
612 return Ok(ExtensionUiResponse {
613 id: request.id.clone(),
614 value,
615 cancelled: false,
616 });
617 }
618
619 if let Some(value_str) = option.get("value").and_then(Value::as_str) {
620 if value_str.to_lowercase() == lowered {
621 return Ok(ExtensionUiResponse {
622 id: request.id.clone(),
623 value: Some(Value::String(value_str.to_string())),
624 cancelled: false,
625 });
626 }
627 }
628 }
629
630 Err("Invalid selection. Enter a number, label, or 'cancel'.".to_string())
631 }
632 _ => Ok(ExtensionUiResponse {
633 id: request.id.clone(),
634 value: Some(Value::String(input.to_string())),
635 cancelled: false,
636 }),
637 }
638}
639
640#[cfg(test)]
641mod tests {
642 use super::*;
643
644 use crate::agent::{Agent, AgentConfig};
645 use crate::config::Config;
646 use crate::model::StreamEvent;
647 use crate::models::ModelEntry;
648 use crate::provider::{Context, InputType, Model, ModelCost, Provider, StreamOptions};
649 use crate::session::{Session, SessionMessage};
650 use crate::tools::ToolRegistry;
651 use asupersync::runtime::RuntimeBuilder;
652 use async_trait::async_trait;
653 use futures::stream;
654 use serde_json::json;
655 use std::collections::HashMap;
656 use std::path::Path;
657 use std::pin::Pin;
658 use std::time::Duration;
659
    /// Boxed event stream type returned by the test provider.
    type TestStream =
        Pin<Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>>;
    /// Host actions under test, plus the receiving end of their event channel and
    /// the session/agent handles they share.
    type HostActionsHarness = (
        InteractiveExtensionHostActions,
        mpsc::Receiver<PiMsg>,
        Arc<Mutex<Session>>,
        Arc<Mutex<Agent>>,
    );
668
    /// Provider stub whose `stream` yields no events; used only to construct an `Agent`.
    struct NoopProvider;

    #[async_trait]
    impl Provider for NoopProvider {
        fn name(&self) -> &'static str {
            "noop"
        }

        fn api(&self) -> &'static str {
            "noop"
        }

        fn model_id(&self) -> &'static str {
            "noop-model"
        }

        // Returns an empty stream regardless of context/options.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<TestStream> {
            Ok(Box::pin(stream::empty()))
        }
    }
693
    /// Builds a host-actions harness with an event channel of capacity 8.
    fn build_host_actions() -> HostActionsHarness {
        build_host_actions_with_capacity(8)
    }
697
    /// Builds a host-actions harness over an in-memory session and a no-op agent.
    ///
    /// `capacity` bounds the event channel so tests can simulate UI backpressure.
    /// Streaming starts disabled and both queues use `OneAtATime` modes.
    fn build_host_actions_with_capacity(capacity: usize) -> HostActionsHarness {
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let provider: Arc<dyn Provider> = Arc::new(NoopProvider);
        let agent = Arc::new(Mutex::new(Agent::new(
            provider,
            ToolRegistry::new(&[], Path::new("."), None),
            AgentConfig::default(),
        )));
        let (event_tx, event_rx) = mpsc::channel(capacity);
        (
            InteractiveExtensionHostActions {
                session: Arc::clone(&session),
                agent: Arc::clone(&agent),
                event_tx,
                extension_streaming: Arc::new(AtomicBool::new(false)),
                user_queue: Arc::new(StdMutex::new(InteractiveMessageQueue::new(
                    QueueMode::OneAtATime,
                    QueueMode::OneAtATime,
                ))),
                injected_queue: Arc::new(StdMutex::new(InjectedMessageQueue::new(
                    QueueMode::OneAtATime,
                    QueueMode::OneAtATime,
                ))),
            },
            event_rx,
            session,
            agent,
        )
    }
727
    /// Minimal non-reasoning `noop`/`noop-model` entry with zero costs and no credentials.
    fn dummy_model_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "noop-model".to_string(),
                name: "Noop Model".to_string(),
                api: "noop".to_string(),
                provider: "noop".to_string(),
                base_url: "https://example.invalid".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 8192,
                max_tokens: 1024,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }
755
    /// Custom messages appended to the session must be returned by `get_messages`
    /// with their type, content, display flag, and details intact.
    #[test]
    fn interactive_extension_session_get_messages_includes_custom_messages() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let cx = Cx::for_request();
            {
                let mut guard = session.lock(&cx).await.expect("lock session");
                guard.append_message(SessionMessage::Custom {
                    custom_type: "note".to_string(),
                    content: "hello".to_string(),
                    display: true,
                    details: Some(json!({ "from": "test" })),
                    timestamp: Some(1),
                });
            }

            let ext_session = InteractiveExtensionSession {
                session,
                model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_compacting: Arc::new(AtomicBool::new(false)),
                config: Config::default(),
                save_enabled: false,
            };

            let messages = ext_session.get_messages().await;
            assert!(
                messages.iter().any(|message| {
                    matches!(
                        message,
                        SessionMessage::Custom {
                            custom_type,
                            content,
                            display,
                            details,
                            ..
                        } if custom_type == "note"
                            && content == "hello"
                            && *display
                            && details.as_ref().and_then(|value| value.get("from").and_then(Value::as_str))
                                == Some("test")
                    )
                }),
                "expected custom message in interactive extension session messages, got {messages:?}"
            );
        });
    }
807
    /// While another holder owns the session lock, `set_name` under a cancelled
    /// ambient context must fail promptly with a lock error and leave the name unset.
    #[test]
    fn interactive_extension_session_set_name_inherits_cancelled_context_when_locked() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let ext_session = InteractiveExtensionSession {
                session: Arc::clone(&session),
                model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_compacting: Arc::new(AtomicBool::new(false)),
                config: Config::default(),
                save_enabled: false,
            };

            // Hold the lock so `set_name` cannot acquire it.
            let hold_cx = Cx::for_request();
            let held_guard = session.lock(&hold_cx).await.expect("lock session");

            // Install a cancelled context as the ambient `Cx::current()`.
            let ambient_cx = Cx::for_testing();
            ambient_cx.set_cancel_requested(true);
            let _current = Cx::set_current(Some(ambient_cx));
            let inner = asupersync::time::timeout(
                asupersync::time::wall_now(),
                Duration::from_millis(100),
                ext_session.set_name("cancelled-name".to_string()),
            )
            .await;
            let outcome = inner.expect("cancelled helper should finish before timeout");
            let err = outcome.expect_err("lock acquisition should honor inherited cancellation");
            assert!(
                err.to_string().contains("session lock failed"),
                "unexpected error: {err}"
            );

            drop(held_guard);

            let cx = Cx::for_request();
            let guard = session.lock(&cx).await.expect("lock session");
            assert_eq!(guard.get_name(), None);
        });
    }
851
    /// An idle `send_message` with `trigger_turn` must enqueue a `Continue` input and
    /// record the message in both the session and the agent history.
    #[test]
    fn idle_send_message_trigger_turn_enqueues_continue() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (actions, mut event_rx, session, agent) = build_host_actions();

            actions
                .send_message(ExtensionSendMessage {
                    extension_id: Some("ext".to_string()),
                    custom_type: "note".to_string(),
                    content: "continue-now".to_string(),
                    display: false,
                    details: None,
                    deliver_as: Some(ExtensionDeliverAs::Steer),
                    trigger_turn: true,
                })
                .await
                .expect("send_message");

            let queued = event_rx.try_recv().expect("continue should be queued");
            assert!(matches!(
                queued,
                PiMsg::EnqueuePendingInput(PendingInput::Continue)
            ));

            let cx = Cx::for_request();
            let session_guard = session.lock(&cx).await.expect("lock session");
            assert!(
                session_guard
                    .to_messages_for_current_path()
                    .iter()
                    .any(|msg| {
                        matches!(
                            msg,
                            ModelMessage::Custom(CustomMessage { custom_type, content, .. })
                                if custom_type == "note" && content == "continue-now"
                        )
                    })
            );
            drop(session_guard);

            let agent_guard = agent.lock(&cx).await.expect("lock agent");
            assert!(agent_guard.messages().iter().any(|msg| {
                matches!(
                    msg,
                    ModelMessage::Custom(CustomMessage { custom_type, content, .. })
                        if custom_type == "note" && content == "continue-now"
                )
            }));
        });
    }
906
    /// `NextTurn` delivery must not enqueue any event, even with `trigger_turn` set.
    #[test]
    fn idle_send_message_next_turn_ignores_trigger_turn() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (actions, mut event_rx, _session, _agent) = build_host_actions();

            actions
                .send_message(ExtensionSendMessage {
                    extension_id: Some("ext".to_string()),
                    custom_type: "note".to_string(),
                    content: "defer".to_string(),
                    display: false,
                    details: None,
                    deliver_as: Some(ExtensionDeliverAs::NextTurn),
                    trigger_turn: true,
                })
                .await
                .expect("send_message");

            assert!(
                event_rx.try_recv().is_err(),
                "nextTurn should stay deferred even when triggerTurn is set"
            );
        });
    }
935
    /// With the event channel full, a streaming `send_message` must wait for room and
    /// still deliver its display note after the message already in the channel.
    #[test]
    fn streaming_send_message_preserves_display_note_under_backpressure() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (actions, mut event_rx, _session, _agent) = build_host_actions_with_capacity(1);
            actions.extension_streaming.store(true, Ordering::SeqCst);
            // Fill the single-slot channel so the note has to wait.
            actions
                .event_tx
                .try_send(PiMsg::System("busy".to_string()))
                .expect("fill bounded event channel");

            let send_message = actions.send_message(ExtensionSendMessage {
                extension_id: Some("ext".to_string()),
                custom_type: "note".to_string(),
                content: "visible".to_string(),
                display: true,
                details: None,
                deliver_as: Some(ExtensionDeliverAs::Steer),
                trigger_turn: false,
            });
            let recv_cx = Cx::for_request();
            let recv_messages = async {
                let first = event_rx.recv(&recv_cx).await.expect("busy message");
                let second = event_rx.recv(&recv_cx).await.expect("display note");
                (first, second)
            };

            // Drain concurrently so the sender can make progress.
            let (result, (first, second)) = futures::join!(send_message, recv_messages);

            result.expect("send_message");
            assert!(matches!(first, PiMsg::System(text) if text == "busy"));
            assert!(matches!(second, PiMsg::SystemNote(text) if text == "visible"));
        });
    }
973
    /// With the event channel full, an idle `send_message` must deliver its display
    /// note and then the `Continue` input, in that order, once room is available.
    #[test]
    fn idle_send_message_preserves_display_and_continue_under_backpressure() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (actions, mut event_rx, _session, _agent) = build_host_actions_with_capacity(1);
            // Fill the single-slot channel so both events have to wait.
            actions
                .event_tx
                .try_send(PiMsg::System("busy".to_string()))
                .expect("fill bounded event channel");

            let send_message = actions.send_message(ExtensionSendMessage {
                extension_id: Some("ext".to_string()),
                custom_type: "note".to_string(),
                content: "continue-now".to_string(),
                display: true,
                details: None,
                deliver_as: Some(ExtensionDeliverAs::Steer),
                trigger_turn: true,
            });
            let recv_cx = Cx::for_request();
            let recv_messages = async {
                let first = event_rx.recv(&recv_cx).await.expect("busy message");
                let second = event_rx.recv(&recv_cx).await.expect("display note");
                let third = event_rx.recv(&recv_cx).await.expect("continue enqueue");
                (first, second, third)
            };

            let (result, (first, second, third)) = futures::join!(send_message, recv_messages);

            result.expect("send_message");
            assert!(matches!(first, PiMsg::System(text) if text == "busy"));
            assert!(matches!(second, PiMsg::SystemNote(text) if text == "continue-now"));
            assert!(matches!(
                third,
                PiMsg::EnqueuePendingInput(PendingInput::Continue)
            ));
        });
    }
1015
    /// With the event channel full, an idle `send_user_message` must still deliver
    /// its text as pending input once room is available.
    #[test]
    fn idle_send_user_message_preserves_text_under_backpressure() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let (actions, mut event_rx, _session, _agent) = build_host_actions_with_capacity(1);
            // Fill the single-slot channel so the input has to wait.
            actions
                .event_tx
                .try_send(PiMsg::System("busy".to_string()))
                .expect("fill bounded event channel");

            let send_message = actions.send_user_message(ExtensionSendUserMessage {
                extension_id: Some("ext".to_string()),
                text: "hello from extension".to_string(),
                deliver_as: None,
            });
            let recv_cx = Cx::for_request();
            let recv_messages = async {
                let first = event_rx.recv(&recv_cx).await.expect("busy message");
                let second = event_rx.recv(&recv_cx).await.expect("user input enqueue");
                (first, second)
            };

            let (result, (first, second)) = futures::join!(send_message, recv_messages);

            result.expect("send_user_message");
            assert!(matches!(first, PiMsg::System(text) if text == "busy"));
            assert!(matches!(
                second,
                PiMsg::EnqueuePendingInput(PendingInput::Text(text))
                    if text == "hello from extension"
            ));
        });
    }
1052
    /// For a non-reasoning shared model, requesting `"high"` must clamp to `"off"`,
    /// and repeating the same request must not append a second change entry.
    #[test]
    fn set_thinking_level_clamps_and_dedupes_for_non_reasoning_models() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let mut entry = dummy_model_entry();
            entry.model.reasoning = false;
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let ext_session = InteractiveExtensionSession {
                session: Arc::clone(&session),
                model_entry: Arc::new(StdMutex::new(entry)),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_compacting: Arc::new(AtomicBool::new(false)),
                config: Config::default(),
                save_enabled: false,
            };

            ext_session
                .set_thinking_level("high".to_string())
                .await
                .expect("first thinking update");
            ext_session
                .set_thinking_level("high".to_string())
                .await
                .expect("second thinking update");

            let cx = Cx::for_request();
            let guard = session.lock(&cx).await.expect("lock session");
            assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
            let thinking_changes = guard
                .entries_for_current_path()
                .iter()
                .filter(|entry| {
                    matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
                })
                .count();
            assert_eq!(thinking_changes, 1);
        });
    }
1094
1095 #[test]
1096 fn set_thinking_level_does_not_clamp_against_stale_shared_model() {
1097 let runtime = RuntimeBuilder::current_thread()
1098 .build()
1099 .expect("runtime build");
1100
1101 runtime.block_on(async {
1102 let session = Arc::new(Mutex::new(Session::in_memory()));
1103 {
1104 let cx = Cx::for_request();
1105 let mut guard = session.lock(&cx).await.expect("lock session");
1106 guard.append_model_change("anthropic".to_string(), "claude-sonnet-4-5".to_string());
1107 guard.set_model_header(
1108 Some("anthropic".to_string()),
1109 Some("claude-sonnet-4-5".to_string()),
1110 None,
1111 );
1112 }
1113
1114 let ext_session = InteractiveExtensionSession {
1115 session: Arc::clone(&session),
1116 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1117 is_streaming: Arc::new(AtomicBool::new(false)),
1118 is_compacting: Arc::new(AtomicBool::new(false)),
1119 config: Config::default(),
1120 save_enabled: false,
1121 };
1122
1123 ext_session
1124 .set_thinking_level("high".to_string())
1125 .await
1126 .expect("thinking update should preserve requested level");
1127
1128 let cx = Cx::for_request();
1129 let guard = session.lock(&cx).await.expect("lock session");
1130 assert_eq!(guard.header.thinking_level.as_deref(), Some("high"));
1131 let thinking_changes = guard
1132 .entries_for_current_path()
1133 .iter()
1134 .filter(|entry| {
1135 matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
1136 })
1137 .count();
1138 assert_eq!(thinking_changes, 1);
1139 });
1140 }
1141
1142 #[test]
1143 fn set_model_avoids_duplicate_history_for_same_target() {
1144 let runtime = RuntimeBuilder::current_thread()
1145 .build()
1146 .expect("runtime build");
1147
1148 runtime.block_on(async {
1149 let session = Arc::new(Mutex::new(Session::in_memory()));
1150 let ext_session = InteractiveExtensionSession {
1151 session: Arc::clone(&session),
1152 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1153 is_streaming: Arc::new(AtomicBool::new(false)),
1154 is_compacting: Arc::new(AtomicBool::new(false)),
1155 config: Config::default(),
1156 save_enabled: false,
1157 };
1158
1159 ext_session
1160 .set_model("anthropic".to_string(), "claude-sonnet-4-5".to_string())
1161 .await
1162 .expect("first model update");
1163 ext_session
1164 .set_model("anthropic".to_string(), "claude-sonnet-4-5".to_string())
1165 .await
1166 .expect("second model update");
1167
1168 let cx = Cx::for_request();
1169 let guard = session.lock(&cx).await.expect("lock session");
1170 let model_changes = guard
1171 .entries_for_current_path()
1172 .iter()
1173 .filter(|entry| matches!(entry, crate::session::SessionEntry::ModelChange(_)))
1174 .count();
1175 assert_eq!(model_changes, 1);
1176 });
1177 }
1178
1179 #[test]
1180 fn set_model_dedupes_provider_alias_targets_without_rewriting_current_branch_state() {
1181 let runtime = RuntimeBuilder::current_thread()
1182 .build()
1183 .expect("runtime build");
1184
1185 runtime.block_on(async {
1186 let session = Arc::new(Mutex::new(Session::in_memory()));
1187 {
1188 let cx = Cx::for_request();
1189 let mut guard = session.lock(&cx).await.expect("lock session");
1190 guard.append_model_change("google".to_string(), "gemini-2.5-pro".to_string());
1191 guard.set_model_header(
1192 Some("google".to_string()),
1193 Some("gemini-2.5-pro".to_string()),
1194 None,
1195 );
1196 }
1197 let ext_session = InteractiveExtensionSession {
1198 session: Arc::clone(&session),
1199 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1200 is_streaming: Arc::new(AtomicBool::new(false)),
1201 is_compacting: Arc::new(AtomicBool::new(false)),
1202 config: Config::default(),
1203 save_enabled: false,
1204 };
1205
1206 ext_session
1207 .set_model("gemini".to_string(), "gemini-2.5-pro".to_string())
1208 .await
1209 .expect("alias target should dedupe");
1210
1211 let cx = Cx::for_request();
1212 let guard = session.lock(&cx).await.expect("lock session");
1213 let branch = guard.entries_for_current_path();
1214 let model_changes: Vec<_> = branch
1215 .iter()
1216 .filter_map(|entry| {
1217 if let crate::session::SessionEntry::ModelChange(change) = entry {
1218 Some((change.provider.as_str(), change.model_id.as_str()))
1219 } else {
1220 None
1221 }
1222 })
1223 .collect();
1224
1225 assert_eq!(model_changes, vec![("google", "gemini-2.5-pro")]);
1226 assert_eq!(guard.header.provider.as_deref(), Some("google"));
1227 assert_eq!(guard.header.model_id.as_deref(), Some("gemini-2.5-pro"));
1228 });
1229 }
1230
1231 #[test]
1232 fn branch_local_model_and_thinking_state_follow_current_path() {
1233 let runtime = RuntimeBuilder::current_thread()
1234 .build()
1235 .expect("runtime build");
1236
1237 runtime.block_on(async {
1238 let mut session_state = Session::in_memory();
1239 let root_id = session_state.append_message(SessionMessage::User {
1240 content: crate::model::UserContent::Text("root".to_string()),
1241 timestamp: Some(0),
1242 });
1243 session_state.append_model_change("openai".to_string(), "gpt-4o".to_string());
1244 let branch_a_thinking = session_state.append_thinking_level_change("low".to_string());
1245 session_state.set_model_header(
1246 Some("openai".to_string()),
1247 Some("gpt-4o".to_string()),
1248 Some("low".to_string()),
1249 );
1250
1251 assert!(session_state.create_branch_from(&root_id));
1252 session_state
1253 .append_model_change("anthropic".to_string(), "claude-sonnet-4-5".to_string());
1254 session_state.append_thinking_level_change("high".to_string());
1255 session_state.set_model_header(
1256 Some("anthropic".to_string()),
1257 Some("claude-sonnet-4-5".to_string()),
1258 Some("high".to_string()),
1259 );
1260
1261 assert!(session_state.navigate_to(&branch_a_thinking));
1262
1263 let session = Arc::new(Mutex::new(session_state));
1264 let ext_session = InteractiveExtensionSession {
1265 session,
1266 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1267 is_streaming: Arc::new(AtomicBool::new(false)),
1268 is_compacting: Arc::new(AtomicBool::new(false)),
1269 config: Config::default(),
1270 save_enabled: false,
1271 };
1272
1273 let state = ext_session.get_state().await;
1274 let (provider, model_id) = ext_session.get_model().await;
1275 let thinking_level = ext_session.get_thinking_level().await;
1276
1277 assert_eq!(provider.as_deref(), Some("openai"));
1278 assert_eq!(model_id.as_deref(), Some("gpt-4o"));
1279 assert_eq!(thinking_level.as_deref(), Some("low"));
1280 assert_eq!(state["model"]["provider"], "openai");
1281 assert_eq!(state["model"]["id"], "gpt-4o");
1282 assert_eq!(state["thinkingLevel"], "low");
1283 });
1284 }
1285
1286 #[test]
1287 fn branch_without_overrides_does_not_inherit_stale_header_state() {
1288 let runtime = RuntimeBuilder::current_thread()
1289 .build()
1290 .expect("runtime build");
1291
1292 runtime.block_on(async {
1293 let mut session_state = Session::in_memory();
1294 let root_id = session_state.append_message(SessionMessage::User {
1295 content: crate::model::UserContent::Text("root".to_string()),
1296 timestamp: Some(0),
1297 });
1298 let branch_a_tip = session_state.append_message(SessionMessage::User {
1299 content: crate::model::UserContent::Text("branch-a".to_string()),
1300 timestamp: Some(0),
1301 });
1302
1303 assert!(session_state.create_branch_from(&root_id));
1304 session_state
1305 .append_model_change("anthropic".to_string(), "claude-sonnet-4-5".to_string());
1306 session_state.append_thinking_level_change("high".to_string());
1307 session_state.set_model_header(
1308 Some("anthropic".to_string()),
1309 Some("claude-sonnet-4-5".to_string()),
1310 Some("high".to_string()),
1311 );
1312
1313 assert!(session_state.navigate_to(&branch_a_tip));
1314
1315 let session = Arc::new(Mutex::new(session_state));
1316 let ext_session = InteractiveExtensionSession {
1317 session,
1318 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1319 is_streaming: Arc::new(AtomicBool::new(false)),
1320 is_compacting: Arc::new(AtomicBool::new(false)),
1321 config: Config::default(),
1322 save_enabled: false,
1323 };
1324
1325 let state = ext_session.get_state().await;
1326 let (provider, model_id) = ext_session.get_model().await;
1327 let thinking_level = ext_session.get_thinking_level().await;
1328
1329 assert!(provider.is_none());
1330 assert!(model_id.is_none());
1331 assert!(thinking_level.is_none());
1332 assert!(state["model"].is_null());
1333 assert_eq!(state["thinkingLevel"], "off");
1334 });
1335 }
1336
1337 #[test]
1338 fn set_model_and_thinking_dedupe_on_switched_branch() {
1339 let runtime = RuntimeBuilder::current_thread()
1340 .build()
1341 .expect("runtime build");
1342
1343 runtime.block_on(async {
1344 let mut session_state = Session::in_memory();
1345 let root_id = session_state.append_message(SessionMessage::User {
1346 content: crate::model::UserContent::Text("root".to_string()),
1347 timestamp: Some(0),
1348 });
1349 session_state.append_model_change("openai".to_string(), "gpt-4o".to_string());
1350 let branch_a_thinking = session_state.append_thinking_level_change("low".to_string());
1351 session_state.set_model_header(
1352 Some("openai".to_string()),
1353 Some("gpt-4o".to_string()),
1354 Some("low".to_string()),
1355 );
1356
1357 assert!(session_state.create_branch_from(&root_id));
1358 session_state
1359 .append_model_change("anthropic".to_string(), "claude-sonnet-4-5".to_string());
1360 session_state.append_thinking_level_change("high".to_string());
1361 session_state.set_model_header(
1362 Some("anthropic".to_string()),
1363 Some("claude-sonnet-4-5".to_string()),
1364 Some("high".to_string()),
1365 );
1366
1367 assert!(session_state.navigate_to(&branch_a_thinking));
1368
1369 let session = Arc::new(Mutex::new(session_state));
1370 let ext_session = InteractiveExtensionSession {
1371 session: Arc::clone(&session),
1372 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1373 is_streaming: Arc::new(AtomicBool::new(false)),
1374 is_compacting: Arc::new(AtomicBool::new(false)),
1375 config: Config::default(),
1376 save_enabled: false,
1377 };
1378
1379 ext_session
1380 .set_model("openai".to_string(), "gpt-4o".to_string())
1381 .await
1382 .expect("same-branch model should dedupe");
1383 ext_session
1384 .set_thinking_level("low".to_string())
1385 .await
1386 .expect("same-branch thinking should dedupe");
1387
1388 let cx = Cx::for_request();
1389 let guard = session.lock(&cx).await.expect("lock session");
1390 let branch = guard.entries_for_current_path();
1391 let model_changes = branch
1392 .iter()
1393 .filter(|entry| matches!(entry, crate::session::SessionEntry::ModelChange(_)))
1394 .count();
1395 let thinking_changes = branch
1396 .iter()
1397 .filter(|entry| {
1398 matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
1399 })
1400 .count();
1401
1402 assert_eq!(model_changes, 1);
1403 assert_eq!(thinking_changes, 1);
1404 });
1405 }
1406
1407 #[test]
1408 fn get_state_reports_configured_queue_modes() {
1409 let runtime = RuntimeBuilder::current_thread()
1410 .build()
1411 .expect("runtime build");
1412
1413 runtime.block_on(async {
1414 let session = Arc::new(Mutex::new(Session::in_memory()));
1415 let ext_session = InteractiveExtensionSession {
1416 session,
1417 model_entry: Arc::new(StdMutex::new(dummy_model_entry())),
1418 is_streaming: Arc::new(AtomicBool::new(false)),
1419 is_compacting: Arc::new(AtomicBool::new(false)),
1420 config: Config {
1421 steering_mode: Some("all".to_string()),
1422 follow_up_mode: Some("one-at-a-time".to_string()),
1423 ..Config::default()
1424 },
1425 save_enabled: false,
1426 };
1427
1428 let state = ext_session.get_state().await;
1429 assert_eq!(state["steeringMode"], "all");
1430 assert_eq!(state["followUpMode"], "one-at-a-time");
1431 });
1432 }
1433}