1use std::collections::{HashMap, HashSet, VecDeque};
2use std::time::{Instant, SystemTime, UNIX_EPOCH};
3
4use chio_core::crypto::{canonical_json_bytes, sha256_hex};
5use chio_core::session::{
6 CompletionResult, CreateElicitationOperation, NormalizedRoot, OperationContext, OperationKind,
7 OperationTerminalState, ProgressToken, PromptDefinition, PromptResult, RequestId,
8 RequestOwnershipSnapshot, ResourceContent, ResourceDefinition, ResourceTemplateDefinition,
9 RootDefinition, SessionAnchorReference, SessionAuthContext, SessionId,
10};
11use chio_core::{AgentId, CapabilityToken};
12
13use crate::{ToolCallResponse, ToolServerEvent};
14use chio_core::receipt::ChioReceipt;
15
/// Lifecycle phase of a session.
///
/// Sessions begin in `Initializing`, become `Ready` on activation, may enter `Draining`,
/// and finish in `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Not yet activated; only capability listing and heartbeats are permitted.
    Initializing,
    /// Fully active; every operation kind is permitted.
    Ready,
    /// Winding down; only capability listing and heartbeats are permitted.
    Draining,
    /// Terminal state; no operations are permitted.
    Closed,
}

impl SessionState {
    /// Stable lowercase label for this state, used when building error values.
    pub fn as_str(self) -> &'static str {
        match self {
            SessionState::Closed => "closed",
            SessionState::Draining => "draining",
            SessionState::Ready => "ready",
            SessionState::Initializing => "initializing",
        }
    }
}
35
/// Capability flags reported by the connected peer, stored per session via
/// `Session::set_peer_capabilities`.
///
/// Every flag defaults to `false`. The struct is serde-serializable, so field names and
/// order form part of its wire/storage representation; do not rename casually.
// NOTE(review): the flag names suggest they mirror the peer's MCP-style capability
// negotiation (progress, cancellation, roots, sampling, elicitation) — confirm at the
// code that populates this struct.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PeerCapabilities {
    pub supports_progress: bool,
    pub supports_cancellation: bool,
    pub supports_subscriptions: bool,
    pub supports_chio_tool_streaming: bool,
    // Roots support, plus whether the peer reports root-list changes.
    pub supports_roots: bool,
    pub roots_list_changed: bool,
    // Sampling support and its sub-features.
    pub supports_sampling: bool,
    pub sampling_context: bool,
    pub sampling_tools: bool,
    // Elicitation support and the elicitation modes (form / URL) the peer accepts.
    pub supports_elicitation: bool,
    pub elicitation_form: bool,
    pub elicitation_url: bool,
}
52
/// Bookkeeping for a request that is currently executing within a session.
///
/// Created by `InflightRegistry::track` and handed back by `InflightRegistry::complete`.
#[derive(Debug, Clone)]
pub struct InflightRequest {
    /// Request identifier copied from the tracking `OperationContext`.
    pub request_id: RequestId,
    /// Parent request for nested operations, copied from the `OperationContext`.
    pub parent_request_id: Option<RequestId>,
    /// Kind of operation this request performs.
    pub operation_kind: OperationKind,
    /// Session anchor id that was current when the request started being tracked.
    pub session_anchor_id: String,
    /// Monotonic timestamp taken (`Instant::now()`) when tracking began.
    pub started_at: Instant,
    /// Progress token copied from the `OperationContext`, if one was supplied.
    pub progress_token: Option<ProgressToken>,
    /// Set to `true` by `InflightRegistry::mark_cancellation_requested`.
    pub cancellation_requested: bool,
    /// Whether cancellation may be requested for this request at all.
    pub cancellable: bool,
}

impl InflightRequest {
    /// Ownership snapshot for an in-flight request.
    ///
    /// Always returns `RequestOwnershipSnapshot::request_owned()`: the request itself owns
    /// its work, its result stream, and its terminal state.
    pub fn ownership(&self) -> RequestOwnershipSnapshot {
        RequestOwnershipSnapshot::request_owned()
    }
}
71
72#[derive(Debug, Clone, Default)]
74pub struct InflightRegistry {
75 requests: HashMap<RequestId, InflightRequest>,
76}
77
78impl InflightRegistry {
79 pub fn track(
80 &mut self,
81 context: &OperationContext,
82 operation_kind: OperationKind,
83 session_anchor_id: &str,
84 cancellable: bool,
85 ) -> Result<(), SessionError> {
86 if self.requests.contains_key(&context.request_id) {
87 return Err(SessionError::DuplicateInflightRequest {
88 request_id: context.request_id.clone(),
89 });
90 }
91
92 self.requests.insert(
93 context.request_id.clone(),
94 InflightRequest {
95 request_id: context.request_id.clone(),
96 parent_request_id: context.parent_request_id.clone(),
97 operation_kind,
98 session_anchor_id: session_anchor_id.to_string(),
99 started_at: Instant::now(),
100 progress_token: context.progress_token.clone(),
101 cancellation_requested: false,
102 cancellable,
103 },
104 );
105 Ok(())
106 }
107
108 pub fn complete(&mut self, request_id: &RequestId) -> Result<InflightRequest, SessionError> {
109 self.requests
110 .remove(request_id)
111 .ok_or_else(|| SessionError::RequestNotInflight {
112 request_id: request_id.clone(),
113 })
114 }
115
116 pub fn mark_cancellation_requested(
117 &mut self,
118 request_id: &RequestId,
119 ) -> Result<(), SessionError> {
120 let request =
121 self.requests
122 .get_mut(request_id)
123 .ok_or_else(|| SessionError::RequestNotInflight {
124 request_id: request_id.clone(),
125 })?;
126
127 if !request.cancellable {
128 return Err(SessionError::RequestNotCancellable {
129 request_id: request_id.clone(),
130 });
131 }
132
133 request.cancellation_requested = true;
134 Ok(())
135 }
136
137 pub fn get(&self, request_id: &RequestId) -> Option<&InflightRequest> {
138 self.requests.get(request_id)
139 }
140
141 pub fn len(&self) -> usize {
142 self.requests.len()
143 }
144
145 pub fn is_empty(&self) -> bool {
146 self.requests.is_empty()
147 }
148
149 pub fn clear(&mut self) {
150 self.requests.clear();
151 }
152}
153
/// A subject a session can subscribe to; currently only resources, identified by URI.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum SubscriptionSubject {
    /// Resource identified by its URI string.
    Resource(String),
}

/// Set of subjects the session is subscribed to.
///
/// Subscribing twice to the same URI is idempotent; unsubscribing from an unknown URI is
/// a no-op.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionRegistry {
    subscriptions: HashSet<SubscriptionSubject>,
}

impl SubscriptionRegistry {
    /// Builds the lookup key for a resource URI.
    fn resource_subject(uri: &str) -> SubscriptionSubject {
        SubscriptionSubject::Resource(uri.to_owned())
    }

    /// Adds a subscription to the resource at `uri`.
    pub fn subscribe_resource(&mut self, uri: impl Into<String>) {
        let subject = SubscriptionSubject::Resource(uri.into());
        self.subscriptions.insert(subject);
    }

    /// Removes the subscription to the resource at `uri`, if present.
    pub fn unsubscribe_resource(&mut self, uri: &str) {
        self.subscriptions.remove(&Self::resource_subject(uri));
    }

    /// `true` when the resource at `uri` is subscribed.
    pub fn contains_resource(&self, uri: &str) -> bool {
        self.subscriptions.contains(&Self::resource_subject(uri))
    }

    /// Number of active subscriptions.
    pub fn len(&self) -> usize {
        self.subscriptions.len()
    }

    /// `true` when there are no subscriptions.
    pub fn is_empty(&self) -> bool {
        self.subscriptions.is_empty()
    }

    /// Removes every subscription.
    pub fn clear(&mut self) {
        self.subscriptions.clear();
    }
}

/// Session-scoped notification queued for later delivery (see `Session::queue_late_event`
/// and `Session::take_late_events`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LateSessionEvent {
    /// A previously registered URL elicitation finished.
    ElicitationCompleted {
        elicitation_id: String,
        related_task_id: Option<String>,
    },
    /// A subscribed resource changed.
    ResourceUpdated { uri: String },
    ResourcesListChanged,
    ToolsListChanged,
    PromptsListChanged,
}

/// URL elicitation awaiting a completion signal, with the task it relates to (if any).
#[derive(Debug, Clone)]
struct PendingUrlElicitation {
    related_task_id: Option<String>,
}
212
213const TERMINAL_HISTORY_LIMIT: usize = 256;
214
215#[derive(Debug, Clone)]
217pub struct TerminalRegistry {
218 states: HashMap<RequestId, OperationTerminalState>,
219 order: VecDeque<RequestId>,
220 limit: usize,
221}
222
223impl Default for TerminalRegistry {
224 fn default() -> Self {
225 Self {
226 states: HashMap::new(),
227 order: VecDeque::new(),
228 limit: TERMINAL_HISTORY_LIMIT,
229 }
230 }
231}
232
233impl TerminalRegistry {
234 pub fn record(&mut self, request_id: RequestId, state: OperationTerminalState) {
235 if !self.states.contains_key(&request_id) {
236 self.order.push_back(request_id.clone());
237 }
238 self.states.insert(request_id, state);
239
240 while self.order.len() > self.limit {
241 if let Some(oldest) = self.order.pop_front() {
242 self.states.remove(&oldest);
243 }
244 }
245 }
246
247 pub fn get(&self, request_id: &RequestId) -> Option<&OperationTerminalState> {
248 self.states.get(request_id)
249 }
250
251 pub fn len(&self) -> usize {
252 self.states.len()
253 }
254
255 pub fn is_empty(&self) -> bool {
256 self.states.is_empty()
257 }
258}
259
/// Errors produced by session lifecycle, context validation, and request bookkeeping.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum SessionError {
    /// Raised by `Session::transition` when `from -> to` is not an allowed lifecycle step.
    #[error("invalid session transition from {from} to {to}")]
    InvalidTransition {
        from: &'static str,
        to: &'static str,
    },

    /// Raised by `Session::ensure_operation_allowed` when the current state forbids the
    /// operation kind.
    #[error("session {session_id} cannot handle {operation} while {state}")]
    OperationNotAllowed {
        session_id: SessionId,
        operation: &'static str,
        state: &'static str,
    },

    /// Raised by `Session::validate_context` when the context targets another session.
    #[error("operation context session {actual} does not match runtime session {expected}")]
    ContextSessionMismatch {
        expected: SessionId,
        actual: SessionId,
    },

    /// Raised by `Session::validate_context` when the context names a different agent.
    #[error("operation context agent {actual} does not match session agent {expected}")]
    ContextAgentMismatch { expected: AgentId, actual: AgentId },

    /// The request id is already being tracked as in flight.
    #[error("request {request_id} is already in flight")]
    DuplicateInflightRequest { request_id: RequestId },

    /// The request id was tracked earlier in this session (lineage records are retained
    /// after completion), so it cannot be reused.
    #[error("request {request_id} already has authoritative lineage in this session")]
    DuplicateRequestLineage { request_id: RequestId },

    /// The request id is not currently tracked as in flight.
    #[error("request {request_id} is not in flight")]
    RequestNotInflight { request_id: RequestId },

    /// Cancellation was requested for a request tracked as non-cancellable.
    #[error("request {request_id} is not cancellable")]
    RequestNotCancellable { request_id: RequestId },

    /// A child request named a parent that is not in flight (or has no lineage record).
    #[error("parent request {parent_request_id} is not in flight for child request {request_id}")]
    ParentRequestNotInflight {
        request_id: RequestId,
        parent_request_id: RequestId,
    },

    /// A child request's parent was started under a session anchor that has since been
    /// rotated (the auth context changed).
    #[error(
        "parent request {parent_request_id} for child request {request_id} belongs to stale session anchor {parent_session_anchor_id}, current anchor is {current_session_anchor_id}"
    )]
    ParentRequestAnchorMismatch {
        request_id: RequestId,
        parent_request_id: RequestId,
        parent_session_anchor_id: String,
        current_session_anchor_id: String,
    },
}
313
314#[derive(Debug, Clone, PartialEq, Eq)]
315pub struct SessionAnchorState {
316 id: String,
317 auth_epoch: u64,
318 auth_context_hash: String,
319 issued_at: u64,
320}
321
322impl SessionAnchorState {
323 fn new(session_id: &SessionId, auth_context: &SessionAuthContext, auth_epoch: u64) -> Self {
324 let auth_context_hash = auth_context_hash(auth_context);
325 let hash_prefix = &auth_context_hash[..12.min(auth_context_hash.len())];
326 Self {
327 id: format!("{session_id}:anchor:{auth_epoch}:{hash_prefix}"),
328 auth_epoch,
329 auth_context_hash,
330 issued_at: current_unix_timestamp(),
331 }
332 }
333
334 pub fn id(&self) -> &str {
335 &self.id
336 }
337
338 pub fn auth_epoch(&self) -> u64 {
339 self.auth_epoch
340 }
341
342 pub fn auth_context_hash(&self) -> &str {
343 &self.auth_context_hash
344 }
345
346 pub fn issued_at(&self) -> u64 {
347 self.issued_at
348 }
349
350 pub fn reference(&self) -> SessionAnchorReference {
351 SessionAnchorReference::new(self.id.clone(), self.auth_context_hash.clone())
352 }
353}
354
/// Durable record of a request tracked by a session, kept after the request completes.
///
/// Written by `Session::track_request`; `terminal_state` is filled in when the request is
/// completed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestLineageRecord {
    pub request_id: RequestId,
    /// Session anchor id that was current when the request was tracked.
    pub session_anchor_id: String,
    /// Auth epoch of that anchor.
    pub auth_epoch: u64,
    pub parent_request_id: Option<RequestId>,
    pub operation_kind: OperationKind,
    /// Unix timestamp in seconds at tracking time.
    pub started_at: u64,
    /// `None` while the request is still running.
    pub terminal_state: Option<OperationTerminalState>,
}
365
/// State tracked for a single session: lifecycle, auth anchor, peer capabilities, roots,
/// subscriptions, and request/event bookkeeping.
#[derive(Debug, Clone)]
pub struct Session {
    /// Must match `OperationContext::session_id` (checked in `validate_context`).
    id: SessionId,
    /// Must match `OperationContext::agent_id` (checked in `validate_context`).
    agent_id: AgentId,
    state: SessionState,
    /// Current anchor; replaced whenever the auth context changes.
    session_anchor: SessionAnchorState,
    auth_context: SessionAuthContext,
    peer_capabilities: PeerCapabilities,
    /// Roots exactly as supplied by the peer.
    roots: Vec<RootDefinition>,
    /// `roots` normalized for runtime enforcement, index-aligned with `roots`.
    normalized_roots: Vec<NormalizedRoot>,
    issued_capabilities: Vec<CapabilityToken>,
    inflight: InflightRegistry,
    subscriptions: SubscriptionRegistry,
    /// Bounded history of terminal outcomes.
    terminal: TerminalRegistry,
    /// Every request ever tracked in this session; never pruned, so request ids cannot
    /// be reused.
    request_lineage: HashMap<RequestId, RequestLineageRecord>,
    /// URL elicitations awaiting completion, keyed by elicitation id.
    pending_url_elicitations: HashMap<String, PendingUrlElicitation>,
    /// Events waiting to be drained by `take_late_events`.
    late_events: VecDeque<LateSessionEvent>,
}
385
386impl Session {
387 pub fn new(
388 id: SessionId,
389 agent_id: AgentId,
390 issued_capabilities: Vec<CapabilityToken>,
391 ) -> Self {
392 let auth_context = SessionAuthContext::in_process_anonymous();
393 let session_anchor = SessionAnchorState::new(&id, &auth_context, 0);
394 Self {
395 id,
396 agent_id,
397 state: SessionState::Initializing,
398 session_anchor,
399 auth_context,
400 peer_capabilities: PeerCapabilities::default(),
401 roots: Vec::new(),
402 normalized_roots: Vec::new(),
403 issued_capabilities,
404 inflight: InflightRegistry::default(),
405 subscriptions: SubscriptionRegistry::default(),
406 terminal: TerminalRegistry::default(),
407 request_lineage: HashMap::new(),
408 pending_url_elicitations: HashMap::new(),
409 late_events: VecDeque::new(),
410 }
411 }
412
413 pub fn id(&self) -> &SessionId {
414 &self.id
415 }
416
417 pub fn agent_id(&self) -> &str {
418 &self.agent_id
419 }
420
421 pub fn state(&self) -> SessionState {
422 self.state
423 }
424
425 pub fn auth_context(&self) -> &SessionAuthContext {
426 &self.auth_context
427 }
428
429 pub fn session_anchor(&self) -> &SessionAnchorState {
430 &self.session_anchor
431 }
432
433 pub fn request_lineage(&self, request_id: &RequestId) -> Option<&RequestLineageRecord> {
434 self.request_lineage.get(request_id)
435 }
436
437 pub fn peer_capabilities(&self) -> &PeerCapabilities {
438 &self.peer_capabilities
439 }
440
441 pub fn capabilities(&self) -> &[CapabilityToken] {
442 &self.issued_capabilities
443 }
444
445 pub fn roots(&self) -> &[RootDefinition] {
446 &self.roots
447 }
448
449 pub fn normalized_roots(&self) -> &[NormalizedRoot] {
450 &self.normalized_roots
451 }
452
453 pub fn enforceable_filesystem_roots(&self) -> impl Iterator<Item = &NormalizedRoot> {
454 self.normalized_roots
455 .iter()
456 .filter(|root| root.is_enforceable_filesystem())
457 }
458
459 pub fn inflight(&self) -> &InflightRegistry {
460 &self.inflight
461 }
462
463 pub fn subscriptions(&self) -> &SubscriptionRegistry {
464 &self.subscriptions
465 }
466
467 pub fn terminal(&self) -> &TerminalRegistry {
468 &self.terminal
469 }
470
471 pub fn register_pending_url_elicitation(
472 &mut self,
473 elicitation_id: impl Into<String>,
474 related_task_id: Option<String>,
475 ) {
476 self.pending_url_elicitations.insert(
477 elicitation_id.into(),
478 PendingUrlElicitation { related_task_id },
479 );
480 }
481
482 pub fn register_required_url_elicitations(
483 &mut self,
484 elicitations: &[CreateElicitationOperation],
485 related_task_id: Option<&str>,
486 ) {
487 for elicitation in elicitations {
488 let CreateElicitationOperation::Url { elicitation_id, .. } = elicitation else {
489 continue;
490 };
491 self.register_pending_url_elicitation(
492 elicitation_id.clone(),
493 related_task_id.map(ToString::to_string),
494 );
495 }
496 }
497
498 pub fn queue_late_event(&mut self, event: LateSessionEvent) {
499 self.late_events.push_back(event);
500 }
501
502 pub fn take_late_events(&mut self) -> Vec<LateSessionEvent> {
503 self.late_events.drain(..).collect()
504 }
505
506 pub fn queue_tool_server_event(&mut self, event: ToolServerEvent) {
507 match event {
508 ToolServerEvent::ElicitationCompleted { elicitation_id } => {
509 let Some(pending) = self.pending_url_elicitations.remove(&elicitation_id) else {
510 return;
511 };
512 self.queue_late_event(LateSessionEvent::ElicitationCompleted {
513 elicitation_id,
514 related_task_id: pending.related_task_id,
515 });
516 }
517 ToolServerEvent::ResourceUpdated { uri } => {
518 if self.is_resource_subscribed(&uri) {
519 self.queue_late_event(LateSessionEvent::ResourceUpdated { uri });
520 }
521 }
522 ToolServerEvent::ResourcesListChanged => {
523 self.queue_late_event(LateSessionEvent::ResourcesListChanged);
524 }
525 ToolServerEvent::ToolsListChanged => {
526 self.queue_late_event(LateSessionEvent::ToolsListChanged);
527 }
528 ToolServerEvent::PromptsListChanged => {
529 self.queue_late_event(LateSessionEvent::PromptsListChanged);
530 }
531 }
532 }
533
534 pub fn queue_elicitation_completion(&mut self, elicitation_id: &str) {
535 let Some(pending) = self.pending_url_elicitations.remove(elicitation_id) else {
536 return;
537 };
538 self.queue_late_event(LateSessionEvent::ElicitationCompleted {
539 elicitation_id: elicitation_id.to_string(),
540 related_task_id: pending.related_task_id,
541 });
542 }
543
544 pub fn subscribe_resource(&mut self, uri: impl Into<String>) {
545 self.subscriptions.subscribe_resource(uri);
546 }
547
548 pub fn unsubscribe_resource(&mut self, uri: &str) {
549 self.subscriptions.unsubscribe_resource(uri);
550 }
551
552 pub fn is_resource_subscribed(&self, uri: &str) -> bool {
553 self.subscriptions.contains_resource(uri)
554 }
555
556 pub fn set_auth_context(&mut self, auth_context: SessionAuthContext) -> bool {
557 let rotated = self.auth_context != auth_context;
558 if rotated {
559 let next_epoch = self.session_anchor.auth_epoch.saturating_add(1);
560 self.session_anchor = SessionAnchorState::new(&self.id, &auth_context, next_epoch);
561 }
562 self.auth_context = auth_context;
563 rotated
564 }
565
566 pub fn set_peer_capabilities(&mut self, peer_capabilities: PeerCapabilities) {
567 self.peer_capabilities = peer_capabilities;
568 }
569
570 pub fn replace_roots(&mut self, roots: Vec<RootDefinition>) {
571 self.normalized_roots = roots
572 .iter()
573 .map(RootDefinition::normalize_for_runtime)
574 .collect();
575 self.roots = roots;
576 }
577
578 pub fn activate(&mut self) -> Result<(), SessionError> {
579 self.transition(SessionState::Ready)
580 }
581
582 pub fn begin_draining(&mut self) -> Result<(), SessionError> {
583 self.transition(SessionState::Draining)
584 }
585
586 pub fn close(&mut self) -> Result<(), SessionError> {
587 self.transition(SessionState::Closed)?;
588 self.inflight.clear();
589 self.subscriptions.clear();
590 self.roots.clear();
591 self.normalized_roots.clear();
592 self.pending_url_elicitations.clear();
593 self.late_events.clear();
594 Ok(())
595 }
596
597 pub fn ensure_operation_allowed(&self, operation: OperationKind) -> Result<(), SessionError> {
598 let allowed = match self.state {
599 SessionState::Initializing => matches!(
600 operation,
601 OperationKind::ListCapabilities | OperationKind::Heartbeat
602 ),
603 SessionState::Ready => true,
604 SessionState::Draining => matches!(
605 operation,
606 OperationKind::ListCapabilities | OperationKind::Heartbeat
607 ),
608 SessionState::Closed => false,
609 };
610
611 if allowed {
612 Ok(())
613 } else {
614 Err(SessionError::OperationNotAllowed {
615 session_id: self.id.clone(),
616 operation: operation.as_str(),
617 state: self.state.as_str(),
618 })
619 }
620 }
621
622 pub fn track_request(
623 &mut self,
624 context: &OperationContext,
625 operation_kind: OperationKind,
626 cancellable: bool,
627 ) -> Result<(), SessionError> {
628 self.validate_context(context)?;
629 if let Some(parent_request_id) = &context.parent_request_id {
630 self.validate_parent_request_lineage(&context.request_id, parent_request_id)?;
631 }
632 if self.inflight.get(&context.request_id).is_some() {
633 return Err(SessionError::DuplicateInflightRequest {
634 request_id: context.request_id.clone(),
635 });
636 }
637 if self.request_lineage.contains_key(&context.request_id) {
638 return Err(SessionError::DuplicateRequestLineage {
639 request_id: context.request_id.clone(),
640 });
641 }
642 self.inflight.track(
643 context,
644 operation_kind,
645 self.session_anchor.id(),
646 cancellable,
647 )?;
648 self.request_lineage.insert(
649 context.request_id.clone(),
650 RequestLineageRecord {
651 request_id: context.request_id.clone(),
652 session_anchor_id: self.session_anchor.id().to_string(),
653 auth_epoch: self.session_anchor.auth_epoch(),
654 parent_request_id: context.parent_request_id.clone(),
655 operation_kind,
656 started_at: current_unix_timestamp(),
657 terminal_state: None,
658 },
659 );
660 Ok(())
661 }
662
663 pub fn complete_request(
664 &mut self,
665 request_id: &RequestId,
666 ) -> Result<InflightRequest, SessionError> {
667 self.complete_request_with_terminal_state(request_id, OperationTerminalState::Completed)
668 }
669
670 pub fn complete_request_with_terminal_state(
671 &mut self,
672 request_id: &RequestId,
673 terminal_state: OperationTerminalState,
674 ) -> Result<InflightRequest, SessionError> {
675 let inflight = self.inflight.complete(request_id)?;
676 self.terminal
677 .record(request_id.clone(), terminal_state.clone());
678 if let Some(lineage) = self.request_lineage.get_mut(request_id) {
679 lineage.terminal_state = Some(terminal_state);
680 }
681 Ok(inflight)
682 }
683
684 pub fn request_cancellation(&mut self, request_id: &RequestId) -> Result<(), SessionError> {
685 self.inflight.mark_cancellation_requested(request_id)
686 }
687
688 pub fn validate_parent_request_lineage(
689 &self,
690 request_id: &RequestId,
691 parent_request_id: &RequestId,
692 ) -> Result<&RequestLineageRecord, SessionError> {
693 let Some(parent_inflight) = self.inflight.get(parent_request_id) else {
694 return Err(SessionError::ParentRequestNotInflight {
695 request_id: request_id.clone(),
696 parent_request_id: parent_request_id.clone(),
697 });
698 };
699 let Some(parent_lineage) = self.request_lineage.get(parent_request_id) else {
700 return Err(SessionError::ParentRequestNotInflight {
701 request_id: request_id.clone(),
702 parent_request_id: parent_request_id.clone(),
703 });
704 };
705 if parent_lineage.session_anchor_id != self.session_anchor.id() {
706 return Err(SessionError::ParentRequestAnchorMismatch {
707 request_id: request_id.clone(),
708 parent_request_id: parent_request_id.clone(),
709 parent_session_anchor_id: parent_inflight.session_anchor_id.clone(),
710 current_session_anchor_id: self.session_anchor.id().to_string(),
711 });
712 }
713 Ok(parent_lineage)
714 }
715
716 fn transition(&mut self, next: SessionState) -> Result<(), SessionError> {
717 let valid = match (self.state, next) {
718 (SessionState::Initializing, SessionState::Ready)
719 | (SessionState::Initializing, SessionState::Closed)
720 | (SessionState::Ready, SessionState::Draining)
721 | (SessionState::Ready, SessionState::Closed)
722 | (SessionState::Draining, SessionState::Closed) => true,
723 _ if self.state == next => true,
724 _ => false,
725 };
726
727 if !valid {
728 return Err(SessionError::InvalidTransition {
729 from: self.state.as_str(),
730 to: next.as_str(),
731 });
732 }
733
734 self.state = next;
735 Ok(())
736 }
737
738 pub fn validate_context(&self, context: &OperationContext) -> Result<(), SessionError> {
739 if context.session_id != self.id {
740 return Err(SessionError::ContextSessionMismatch {
741 expected: self.id.clone(),
742 actual: context.session_id.clone(),
743 });
744 }
745
746 if context.agent_id != self.agent_id {
747 return Err(SessionError::ContextAgentMismatch {
748 expected: self.agent_id.clone(),
749 actual: context.agent_id.clone(),
750 });
751 }
752
753 Ok(())
754 }
755}
756
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 instead of panicking if the system clock reports a time before 1970.
fn current_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
763
764fn auth_context_hash(auth_context: &SessionAuthContext) -> String {
765 canonical_json_bytes(auth_context)
766 .map(|bytes| sha256_hex(&bytes))
767 .unwrap_or_else(|_| "session-auth-context-hash-unavailable".to_string())
768}
769
/// Result payload of an operation executed within a session.
// NOTE(review): variants appear to correspond one-to-one with operation kinds (tool call,
// roots, resources, prompts, completion, capabilities, heartbeat) — confirm at the dispatcher.
#[derive(Debug)]
pub enum SessionOperationResponse {
    ToolCall(ToolCallResponse),
    RootList {
        roots: Vec<RootDefinition>,
    },
    ResourceList {
        resources: Vec<ResourceDefinition>,
    },
    ResourceRead {
        contents: Vec<ResourceContent>,
    },
    /// A resource read that was denied; carries the receipt describing the denial.
    ResourceReadDenied {
        receipt: ChioReceipt,
    },
    ResourceTemplateList {
        templates: Vec<ResourceTemplateDefinition>,
    },
    PromptList {
        prompts: Vec<PromptDefinition>,
    },
    PromptGet {
        prompt: PromptResult,
    },
    Completion {
        completion: CompletionResult,
    },
    CapabilityList {
        capabilities: Vec<CapabilityToken>,
    },
    Heartbeat,
}
803
// Unit tests for session lifecycle, roots, request tracking, anchor rotation, and
// late-event queueing.
#[cfg(test)]
// Tests unwrap/expect deliberately: a panic is the intended failure signal here.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an `OperationContext` for session "sess-1" / agent "agent-1" with no parent
    /// request and a fixed string progress token.
    fn make_context(request_id: &str) -> OperationContext {
        OperationContext {
            session_id: SessionId::new("sess-1"),
            request_id: RequestId::new(request_id),
            agent_id: "agent-1".to_string(),
            parent_request_id: None,
            progress_token: Some(ProgressToken::String("progress-1".to_string())),
        }
    }

    // Walks the full happy-path lifecycle: Initializing -> Ready -> Draining -> Closed.
    #[test]
    fn lifecycle_transitions_cover_ready_draining_closed() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());

        assert_eq!(session.state(), SessionState::Initializing);
        session.activate().unwrap();
        assert_eq!(session.state(), SessionState::Ready);
        session.begin_draining().unwrap();
        assert_eq!(session.state(), SessionState::Draining);
        session.close().unwrap();
        assert_eq!(session.state(), SessionState::Closed);
    }

    // Tool calls are rejected both before activation and while draining.
    #[test]
    fn tool_calls_not_allowed_during_initializing_or_draining() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());

        let err = session
            .ensure_operation_allowed(OperationKind::ToolCall)
            .unwrap_err();
        assert!(matches!(err, SessionError::OperationNotAllowed { .. }));

        session.activate().unwrap();
        session.begin_draining().unwrap();

        let err = session
            .ensure_operation_allowed(OperationKind::ToolCall)
            .unwrap_err();
        assert!(matches!(err, SessionError::OperationNotAllowed { .. }));
    }

    // Peer capabilities and roots are stored on the session; roots are cleared on close.
    #[test]
    fn peer_capabilities_and_roots_are_session_scoped() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());

        session.set_peer_capabilities(PeerCapabilities {
            supports_progress: false,
            supports_cancellation: false,
            supports_subscriptions: false,
            supports_chio_tool_streaming: false,
            supports_roots: true,
            roots_list_changed: true,
            supports_sampling: true,
            sampling_context: true,
            sampling_tools: false,
            supports_elicitation: false,
            elicitation_form: false,
            elicitation_url: false,
        });
        session.replace_roots(vec![RootDefinition {
            uri: "file:///workspace/project".to_string(),
            name: Some("Project".to_string()),
        }]);

        assert!(session.peer_capabilities().supports_roots);
        assert!(session.peer_capabilities().roots_list_changed);
        assert_eq!(session.roots().len(), 1);
        assert_eq!(session.roots()[0].uri, "file:///workspace/project");
        assert_eq!(session.normalized_roots().len(), 1);
        assert!(matches!(
            session.normalized_roots()[0],
            NormalizedRoot::EnforceableFileSystem {
                ref normalized_path,
                ..
            } if normalized_path == "/workspace/project"
        ));
        assert_eq!(session.enforceable_filesystem_roots().count(), 1);

        session.close().unwrap();
        assert!(session.roots().is_empty());
        assert!(session.normalized_roots().is_empty());
    }

    // Non-file and remote-authority roots are kept as metadata but are not enforceable.
    #[test]
    fn mixed_roots_preserve_metadata_without_widening_enforceable_set() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        session.replace_roots(vec![
            RootDefinition {
                uri: "file:///workspace/project/src".to_string(),
                name: Some("Code".to_string()),
            },
            RootDefinition {
                uri: "repo://docs/roadmap".to_string(),
                name: Some("Roadmap".to_string()),
            },
            RootDefinition {
                uri: "file://remote-host/workspace/project".to_string(),
                name: Some("Remote".to_string()),
            },
        ]);

        assert_eq!(session.normalized_roots().len(), 3);
        assert!(matches!(
            session.normalized_roots()[0],
            NormalizedRoot::EnforceableFileSystem {
                ref normalized_path,
                ..
            } if normalized_path == "/workspace/project/src"
        ));
        assert!(matches!(
            session.normalized_roots()[1],
            NormalizedRoot::NonFileSystem { ref scheme, .. } if scheme == "repo"
        ));
        assert!(matches!(
            session.normalized_roots()[2],
            NormalizedRoot::UnenforceableFileSystem { ref reason, .. }
                if reason == "non_local_file_authority"
        ));
        assert_eq!(session.enforceable_filesystem_roots().count(), 1);
    }

    // Tracking then completing a request empties the in-flight set and records `Completed`.
    #[test]
    fn inflight_registry_tracks_and_completes_requests() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let context = make_context("req-1");

        session.activate().unwrap();
        session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap();
        assert_eq!(session.inflight().len(), 1);

        let completed = session.complete_request(&context.request_id).unwrap();
        assert_eq!(completed.request_id, RequestId::new("req-1"));
        assert_eq!(completed.parent_request_id, None);
        assert!(completed.cancellable);
        assert!(session.inflight().is_empty());
        assert_eq!(
            session.terminal().get(&context.request_id),
            Some(&OperationTerminalState::Completed)
        );
    }

    // A child request whose parent was never tracked is rejected.
    #[test]
    fn child_request_requires_parent_inflight() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let mut child_context = make_context("req-child");
        child_context.parent_request_id = Some(RequestId::new("req-parent"));

        session.activate().unwrap();
        let err = session
            .track_request(&child_context, OperationKind::CreateMessage, true)
            .unwrap_err();
        assert!(matches!(err, SessionError::ParentRequestNotInflight { .. }));
    }

    // Tracking the same request id twice while it is in flight fails.
    #[test]
    fn duplicate_inflight_request_is_rejected() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let context = make_context("req-1");

        session.activate().unwrap();
        session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap();

        let err = session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap_err();
        assert!(matches!(err, SessionError::DuplicateInflightRequest { .. }));
    }

    // Requesting cancellation sets the flag on a cancellable in-flight request.
    #[test]
    fn cancellation_marks_cancellable_request() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let context = make_context("req-1");

        session.activate().unwrap();
        session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap();
        session.request_cancellation(&context.request_id).unwrap();

        let inflight = session.inflight().get(&context.request_id).unwrap();
        assert!(inflight.cancellation_requested);
    }

    // In-flight requests report request-owned work, stream, and terminal state.
    #[test]
    fn inflight_request_reports_request_owned_semantics() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let context = make_context("req-1");

        session.activate().unwrap();
        session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap();

        let inflight = session.inflight().get(&context.request_id).unwrap();
        let ownership = inflight.ownership();
        assert_eq!(ownership.work_owner, chio_core::session::WorkOwner::Request);
        assert_eq!(
            ownership.result_stream_owner,
            chio_core::session::StreamOwner::RequestStream
        );
        assert_eq!(
            ownership.terminal_state_owner,
            chio_core::session::WorkOwner::Request
        );
    }

    // A custom terminal state (Cancelled) is recorded verbatim in the terminal registry.
    #[test]
    fn complete_request_can_record_cancelled_terminal_state() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let context = make_context("req-1");

        session.activate().unwrap();
        session
            .track_request(&context, OperationKind::ToolCall, true)
            .unwrap();
        session.request_cancellation(&context.request_id).unwrap();
        session
            .complete_request_with_terminal_state(
                &context.request_id,
                OperationTerminalState::Cancelled {
                    reason: "cancelled by client".to_string(),
                },
            )
            .unwrap();

        assert!(session.inflight().is_empty());
        assert_eq!(
            session.terminal().get(&context.request_id),
            Some(&OperationTerminalState::Cancelled {
                reason: "cancelled by client".to_string(),
            })
        );
    }

    // Closing the session removes all resource subscriptions.
    #[test]
    fn resource_subscriptions_are_cleared_on_close() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());

        session.activate().unwrap();
        session.subscribe_resource("repo://docs/roadmap");

        assert!(session.is_resource_subscribed("repo://docs/roadmap"));
        assert_eq!(session.subscriptions().len(), 1);

        session.close().unwrap();

        assert!(!session.is_resource_subscribed("repo://docs/roadmap"));
        assert_eq!(session.subscriptions().len(), 0);
    }

    // Changing the auth context mints a new anchor with a bumped epoch and a new hash.
    #[test]
    fn session_anchor_rotates_on_auth_context_change() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let initial_anchor = session.session_anchor().clone();
        assert_eq!(
            session.auth_context(),
            &SessionAuthContext::in_process_anonymous()
        );

        let rotated = session.set_auth_context(SessionAuthContext::streamable_http_static_bearer(
            "static-bearer:abcd1234",
            "cafebabe",
            Some("http://localhost:3000".to_string()),
        ));

        assert!(rotated);
        assert!(session.auth_context().is_authenticated());
        assert_eq!(
            session.auth_context().principal(),
            Some("static-bearer:abcd1234")
        );
        assert_ne!(session.session_anchor().id(), initial_anchor.id());
        assert_eq!(session.session_anchor().auth_epoch(), 1);
        assert_ne!(
            session.session_anchor().auth_context_hash(),
            initial_anchor.auth_context_hash()
        );
    }

    // Re-setting an identical auth context leaves the anchor untouched.
    #[test]
    fn session_anchor_does_not_rotate_when_auth_context_is_unchanged() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let auth_context = SessionAuthContext::streamable_http_static_bearer(
            "static-bearer:abcd1234",
            "cafebabe",
            Some("http://localhost:3000".to_string()),
        );

        assert!(session.set_auth_context(auth_context.clone()));
        let rotated_anchor = session.session_anchor().clone();
        assert!(!session.set_auth_context(auth_context));

        assert_eq!(session.session_anchor(), &rotated_anchor);
    }

    // A parent tracked before an anchor rotation cannot parent a child afterwards.
    #[test]
    fn child_request_is_rejected_after_parent_anchor_rotation() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        let parent_context = make_context("req-parent");
        let mut child_context = make_context("req-child");
        child_context.parent_request_id = Some(parent_context.request_id.clone());

        session.activate().unwrap();
        session
            .track_request(&parent_context, OperationKind::ToolCall, true)
            .unwrap();
        assert!(
            session.set_auth_context(SessionAuthContext::streamable_http_static_bearer(
                "static-bearer:abcd1234",
                "cafebabe",
                Some("http://localhost:3000".to_string()),
            ))
        );

        let err = session
            .track_request(&child_context, OperationKind::CreateMessage, true)
            .unwrap_err();
        assert!(matches!(
            err,
            SessionError::ParentRequestAnchorMismatch { .. }
        ));
    }

    // Only registered URL elicitations produce completion events; draining empties the queue.
    #[test]
    fn url_elicitation_completions_become_session_late_events() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        session.register_pending_url_elicitation("elicit-1", Some("task-7".to_string()));

        session.queue_elicitation_completion("elicit-1");
        session.queue_elicitation_completion("unknown");

        assert_eq!(
            session.take_late_events(),
            vec![LateSessionEvent::ElicitationCompleted {
                elicitation_id: "elicit-1".to_string(),
                related_task_id: Some("task-7".to_string()),
            }]
        );
        assert!(session.take_late_events().is_empty());
    }

    // Tool-server events are filtered by subscription / pending state and kept in order.
    #[test]
    fn tool_server_events_are_filtered_and_stored_per_session() {
        let mut session = Session::new(SessionId::new("sess-1"), "agent-1".to_string(), Vec::new());
        session.activate().unwrap();
        session.subscribe_resource("repo://docs/roadmap");
        session.register_pending_url_elicitation("elicit-2", None);

        session.queue_tool_server_event(ToolServerEvent::ResourceUpdated {
            uri: "repo://secret/ops".to_string(),
        });
        session.queue_tool_server_event(ToolServerEvent::ResourceUpdated {
            uri: "repo://docs/roadmap".to_string(),
        });
        session.queue_tool_server_event(ToolServerEvent::ResourcesListChanged);
        session.queue_tool_server_event(ToolServerEvent::ElicitationCompleted {
            elicitation_id: "elicit-2".to_string(),
        });

        assert_eq!(
            session.take_late_events(),
            vec![
                LateSessionEvent::ResourceUpdated {
                    uri: "repo://docs/roadmap".to_string(),
                },
                LateSessionEvent::ResourcesListChanged,
                LateSessionEvent::ElicitationCompleted {
                    elicitation_id: "elicit-2".to_string(),
                    related_task_id: None,
                },
            ]
        );
    }
}