1use std::{collections::BTreeMap, fmt, future::Future, pin::Pin, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use sha3::{Digest, Sha3_256};
6
/// Snapshot format version; `SessionMeta::default()` stamps this value into
/// its `snapshot_version` field.
pub const SNAPSHOT_VERSION: u32 = 1;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct SessionId(pub String);
11
12impl SessionId {
13 pub fn new(id: impl Into<String>) -> Result<Self, ErrorObject> {
14 let id = id.into();
15 if is_valid_session_id(&id) {
16 Ok(Self(id))
17 } else {
18 Err(ErrorObject::new(
19 "INVALID_SESSION_ID",
20 "Session id must match ^[a-zA-Z0-9_\\-]{1,64}$.",
21 ))
22 }
23 }
24}
25
26impl Default for SessionId {
27 fn default() -> Self {
28 Self("default".to_owned())
29 }
30}
31
32impl fmt::Display for SessionId {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 f.write_str(&self.0)
35 }
36}
37
/// Returns `true` when `id` is between 1 and 64 bytes long and every byte is an
/// ASCII letter, an ASCII digit, `_`, or `-`.
fn is_valid_session_id(id: &str) -> bool {
    let allowed = |byte: u8| byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'-';
    (1..=64).contains(&id.len()) && id.bytes().all(allowed)
}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48#[derive(Default)]
49pub enum Language {
50 #[default]
51 Python,
52 TypeScript,
53}
54
55
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct SessionLimits {
58 pub memory_mb: u32,
59 pub cpu_ms: u32,
60 pub wall_ms: u32,
61 pub max_stdout_bytes: u32,
62 pub max_external_calls: u32,
63 pub max_stack_depth: u16,
64}
65
66impl Default for SessionLimits {
67 fn default() -> Self {
68 Self {
69 memory_mb: 64,
70 cpu_ms: 2_000,
71 wall_ms: 5_000,
72 max_stdout_bytes: 65_536,
73 max_external_calls: 32,
74 max_stack_depth: 256,
75 }
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct SessionMeta {
81 pub id: SessionId,
82 pub created_at: u64,
83 pub last_used_at: u64,
84 pub language: Language,
85 pub limits: SessionLimits,
86 pub snapshot_version: u32,
87}
88
89impl Default for SessionMeta {
90 fn default() -> Self {
91 Self {
92 id: SessionId::default(),
93 created_at: 0,
94 last_used_at: 0,
95 language: Language::Python,
96 limits: SessionLimits::default(),
97 snapshot_version: SNAPSHOT_VERSION,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(rename_all = "snake_case")]
104#[derive(Default)]
105pub enum SideEffect {
106 #[default]
107 None,
108 Read,
109 Write,
110 Network,
111 Database,
112 ExternalSystem,
113}
114
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CapabilityLimits {
118 pub timeout_ms: u32,
119 pub max_request_bytes: u32,
120 pub max_response_bytes: u32,
121 pub concurrency: u32,
122 pub rate_per_min: Option<u32>,
123}
124
125impl Default for CapabilityLimits {
126 fn default() -> Self {
127 Self {
128 timeout_ms: 5_000,
129 max_request_bytes: 1_048_576,
130 max_response_bytes: 1_048_576,
131 concurrency: 16,
132 rate_per_min: None,
133 }
134 }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139#[derive(Default)]
140pub enum ApprovalPolicy {
141 #[default]
142 None,
143 Auto { rule: String },
144 Manual,
145}
146
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Capability {
150 pub name: String,
151 pub description: String,
152 pub input_schema: Value,
153 pub output_schema: Value,
154 pub side_effect: SideEffect,
155 pub limits: CapabilityLimits,
156 pub approval_policy: ApprovalPolicy,
157 pub idempotent: bool,
158}
159
160impl Capability {
161 pub fn new(
162 name: impl Into<String>,
163 description: impl Into<String>,
164 side_effect: SideEffect,
165 ) -> Self {
166 Self {
167 name: name.into(),
168 description: description.into(),
169 input_schema: json!({"type": "array"}),
170 output_schema: json!({}),
171 side_effect,
172 limits: CapabilityLimits::default(),
173 approval_policy: ApprovalPolicy::None,
174 idempotent: true,
175 }
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct RunRequest {
181 pub session_id: SessionId,
182 pub language: Language,
183 pub code: String,
184 pub inputs: Map<String, Value>,
185 pub timeout_ms: Option<u32>,
186 pub limits: Option<SessionLimits>,
187 pub return_snapshot: bool,
188 pub validate_only: bool,
189}
190
191impl RunRequest {
192 pub fn new(
193 session_id: impl Into<String>,
194 code: impl Into<String>,
195 ) -> Result<Self, ErrorObject> {
196 Ok(Self {
197 session_id: SessionId::new(session_id)?,
198 language: Language::Python,
199 code: code.into(),
200 inputs: Map::new(),
201 timeout_ms: None,
202 limits: None,
203 return_snapshot: false,
204 validate_only: false,
205 })
206 }
207}
208
209impl Default for RunRequest {
210 fn default() -> Self {
211 Self {
212 session_id: SessionId::default(),
213 language: Language::Python,
214 code: String::new(),
215 inputs: Map::new(),
216 timeout_ms: None,
217 limits: None,
218 return_snapshot: false,
219 validate_only: false,
220 }
221 }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "snake_case")]
226#[derive(Default)]
227pub enum RunStatus {
228 #[default]
229 Ok,
230 ValidationError,
231 RuntimeError,
232 Timeout,
233 Cancelled,
234 ResourceExhausted,
235 PermissionDenied,
236 WaitingForApproval,
237 Interrupted,
238}
239
240
241#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
242#[serde(rename_all = "snake_case")]
243#[derive(Default)]
244pub enum Severity {
245 Error,
246 Warning,
247 #[default]
248 Info,
249}
250
251
252#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
253pub struct Span {
254 pub line: u32,
255 pub column: u32,
256 pub end_line: Option<u32>,
257 pub end_column: Option<u32>,
258}
259
260impl Default for Span {
261 fn default() -> Self {
262 Self {
263 line: 1,
264 column: 1,
265 end_line: None,
266 end_column: None,
267 }
268 }
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct Diagnostic {
273 pub severity: Severity,
274 pub code: String,
275 pub message: String,
276 pub hint: Option<String>,
277 pub span: Option<Span>,
278}
279
280impl Diagnostic {
281 pub fn error(error: &ErrorObject) -> Self {
282 Self {
283 severity: Severity::Error,
284 code: error.code.clone(),
285 message: error.message.clone(),
286 hint: error.hint.clone(),
287 span: error.span,
288 }
289 }
290}
291
292#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
293#[serde(rename_all = "snake_case")]
294#[derive(Default)]
295pub enum CallStatus {
296 #[default]
297 Ok,
298 Error,
299 Timeout,
300 Denied,
301 ApprovalPending,
302}
303
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub struct ErrorObject {
307 pub code: String,
308 pub message: String,
309 pub hint: Option<String>,
310 pub span: Option<Span>,
311}
312
313impl ErrorObject {
314 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
315 Self {
316 code: code.into(),
317 message: message.into(),
318 hint: None,
319 span: None,
320 }
321 }
322
323 pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
324 self.hint = Some(hint.into());
325 self
326 }
327
328 pub fn with_span(mut self, span: Span) -> Self {
329 self.span = Some(span);
330 self
331 }
332}
333
/// Record of one external (tool) call made during a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalCallRecord {
    /// Name of the capability that was called.
    pub name: String,
    /// Side-effect class of the called capability.
    pub side_effect: SideEffect,
    /// Call duration; presumably wall-clock milliseconds — confirm at the producer.
    pub duration_ms: u32,
    /// Outcome of the call.
    pub status: CallStatus,
    /// Digest of the request. NOTE(review): presumably produced by
    /// `digest_json`/`digest_bytes` — confirm at the producer.
    pub request_digest: String,
    /// Digest of the response, when one was received.
    pub response_digest: Option<String>,
    /// Error details for calls that did not succeed.
    pub error: Option<ErrorObject>,
}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
346#[derive(Default)]
347pub struct Metrics {
348 pub duration_ms: u32,
349 pub memory_peak_bytes: u64,
350 pub instructions: u64,
351 pub external_calls_count: u32,
352}
353
354
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct RunResult {
357 pub status: RunStatus,
358 pub result: Option<Value>,
359 pub stdout: String,
360 pub stderr: String,
361 pub diagnostics: Vec<Diagnostic>,
362 pub external_calls: Vec<ExternalCallRecord>,
363 pub snapshot_id: Option<String>,
364 pub metrics: Metrics,
365 pub error: Option<ErrorObject>,
366}
367
368impl RunResult {
369 pub fn ok(result: Option<Value>, stdout: String, metrics: Metrics) -> Self {
370 Self {
371 status: RunStatus::Ok,
372 result,
373 stdout,
374 stderr: String::new(),
375 diagnostics: Vec::new(),
376 external_calls: Vec::new(),
377 snapshot_id: None,
378 metrics,
379 error: None,
380 }
381 }
382
383 pub fn error(status: RunStatus, error: ErrorObject, stdout: String, metrics: Metrics) -> Self {
384 Self {
385 status,
386 result: None,
387 stdout,
388 stderr: String::new(),
389 diagnostics: vec![Diagnostic::error(&error)],
390 external_calls: Vec::new(),
391 snapshot_id: None,
392 metrics,
393 error: Some(error),
394 }
395 }
396}
397
398impl Default for RunResult {
399 fn default() -> Self {
400 Self::ok(None, String::new(), Metrics::default())
401 }
402}
403
/// Arguments delivered to a tool handler for one invocation.
#[derive(Debug, Clone)]
pub struct ToolCallContext {
    /// Name of the invoked tool; presumably the `Capability::name` it was
    /// registered under — confirm at the call site.
    pub name: String,
    /// Positional arguments (presumably Python-style `*args`).
    pub args: Vec<Value>,
    /// Keyword arguments (presumably Python-style `**kwargs`).
    pub kwargs: Map<String, Value>,
}
410
/// Error returned by a tool handler: a machine-readable `code` plus a
/// human-readable `message`.
///
/// Implements [`fmt::Display`] (`"CODE: message"`) and [`std::error::Error`]
/// so handlers can use `?` and callers can format or box it.
#[derive(Debug, Clone)]
pub struct ToolError {
    pub code: String,
    pub message: String,
}

impl ToolError {
    /// Creates a tool error from a code and a message.
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            code: code.into(),
            message: message.into(),
        }
    }
}

impl fmt::Display for ToolError {
    /// Formats as `CODE: message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.code, self.message)
    }
}

impl std::error::Error for ToolError {}
425
/// Outcome of a tool call: a JSON value on success, a [`ToolError`] otherwise.
pub type ToolResult = Result<Value, ToolError>;
/// Boxed, pinned, `Send` future that resolves to a [`ToolResult`].
pub type ToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send + 'static>>;
/// Type-erased tool handler: takes a call context and returns a [`ToolFuture`].
/// It is `Send + Sync` so it can be shared behind an `Arc`.
pub type ToolFn = dyn Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static;
429
430#[derive(Clone)]
431pub struct RegisteredTool {
432 pub capability: Capability,
433 pub async_mode: bool,
434 handler: Arc<ToolFn>,
435}
436
437impl fmt::Debug for RegisteredTool {
438 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439 f.debug_struct("RegisteredTool")
440 .field("capability", &self.capability)
441 .field("async_mode", &self.async_mode)
442 .finish_non_exhaustive()
443 }
444}
445
446impl RegisteredTool {
447 pub fn sync(
448 capability: Capability,
449 handler: impl Fn(ToolCallContext) -> ToolResult + Send + Sync + 'static,
450 ) -> Self {
451 let handler = Arc::new(move |ctx| {
452 let result = handler(ctx);
453 Box::pin(async move { result }) as ToolFuture
454 });
455 Self {
456 capability,
457 async_mode: false,
458 handler,
459 }
460 }
461
462 pub fn asynchronous(
463 capability: Capability,
464 handler: impl Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static,
465 ) -> Self {
466 Self {
467 capability,
468 async_mode: true,
469 handler: Arc::new(handler),
470 }
471 }
472
473 pub fn call(&self, ctx: ToolCallContext) -> ToolFuture {
474 (self.handler)(ctx)
475 }
476}
477
478#[derive(Debug, Clone, Default)]
479pub struct ToolRegistry {
480 tools: BTreeMap<String, RegisteredTool>,
481}
482
483impl ToolRegistry {
484 pub fn new() -> Self {
485 Self::default()
486 }
487
488 pub fn register(&mut self, tool: RegisteredTool) -> Result<(), ErrorObject> {
489 if !is_python_identifier(&tool.capability.name) {
490 return Err(ErrorObject::new(
491 "INVALID_TOOL_NAME",
492 format!(
493 "Capability name '{}' is not a valid Python identifier.",
494 tool.capability.name
495 ),
496 ));
497 }
498 self.tools.insert(tool.capability.name.clone(), tool);
499 Ok(())
500 }
501
502 pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
503 self.tools.get(name)
504 }
505
506 pub fn contains(&self, name: &str) -> bool {
507 self.tools.contains_key(name)
508 }
509
510 pub fn capabilities(&self) -> Vec<Capability> {
511 self.tools
512 .values()
513 .map(|tool| tool.capability.clone())
514 .collect()
515 }
516
517 pub fn names(&self) -> Vec<String> {
518 self.tools.keys().cloned().collect()
519 }
520}
521
/// Returns `true` when `name` can be used as an ordinary name in Python source.
///
/// The rules are:
/// - `name` must be a non-empty ASCII identifier: a letter or `_`, followed by
///   letters, digits, or `_`. This is stricter than Python, which also accepts
///   non-ASCII identifiers.
/// - `name` must not be a hard keyword (`class`, `None`, `lambda`, ...). Keywords
///   match the identifier grammar but cannot be used as names, so a tool called
///   `class` could never be invoked by name.
/// - Soft keywords (`match`, `case`, `type`, `_`) are still accepted, because
///   Python allows them as names.
pub fn is_python_identifier(name: &str) -> bool {
    // Hard keywords as listed by Python 3.12's `keyword.kwlist`.
    const PYTHON_HARD_KEYWORDS: [&str; 35] = [
        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
        "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise",
        "return", "try", "while", "with", "yield",
    ];

    let mut chars = name.chars();
    let Some(first) = chars.next() else {
        return false;
    };
    let well_formed = (first == '_' || first.is_ascii_alphabetic())
        && chars.all(|c| c == '_' || c.is_ascii_alphanumeric());
    well_formed && !PYTHON_HARD_KEYWORDS.contains(&name)
}
530
531pub fn digest_json(value: &Value) -> String {
532 let bytes = serde_json::to_vec(value).unwrap_or_default();
533 let digest = Sha3_256::digest(bytes);
534 to_hex(&digest)
535}
536
537pub fn digest_bytes(bytes: &[u8]) -> String {
538 let digest = Sha3_256::digest(bytes);
539 to_hex(&digest)
540}
541
/// Encodes `bytes` as lowercase hexadecimal, high nibble first, two characters
/// per byte.
fn to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| char::from_digit(u32::from(nibble), 16).expect("a nibble is always < 16"))
        .collect()
}
551
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock is set before the epoch. Saturates at
/// `u64::MAX` if the millisecond count does not fit.
pub fn now_unix_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => return 0,
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validates_session_ids() {
        let accepted = ["agent-123_ok"];
        let rejected = ["bad/slash", ""];
        for id in accepted {
            assert!(SessionId::new(id).is_ok(), "expected {id:?} to be accepted");
        }
        for id in rejected {
            assert!(SessionId::new(id).is_err(), "expected {id:?} to be rejected");
        }
    }

    #[test]
    fn validates_python_identifiers() {
        let cases = [("fetch_json", true), ("1_fetch", false), ("fetch-json", false)];
        for (name, expected) in cases {
            assert_eq!(is_python_identifier(name), expected, "unexpected verdict for {name:?}");
        }
    }
}