Skip to main content

langshell_core/
lib.rs

1use std::{collections::BTreeMap, fmt, future::Future, pin::Pin, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use sha3::{Digest, Sha3_256};
6
7pub const SNAPSHOT_VERSION: u32 = 1;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct SessionId(pub String);
11
12impl SessionId {
13    pub fn new(id: impl Into<String>) -> Result<Self, ErrorObject> {
14        let id = id.into();
15        if is_valid_session_id(&id) {
16            Ok(Self(id))
17        } else {
18            Err(ErrorObject::new(
19                "INVALID_SESSION_ID",
20                "Session id must match ^[a-zA-Z0-9_\\-]{1,64}$.",
21            ))
22        }
23    }
24}
25
26impl Default for SessionId {
27    fn default() -> Self {
28        Self("default".to_owned())
29    }
30}
31
32impl fmt::Display for SessionId {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.write_str(&self.0)
35    }
36}
37
/// Returns `true` when `id` is 1 to 64 bytes long and consists solely of
/// ASCII letters, ASCII digits, `_` or `-`.
fn is_valid_session_id(id: &str) -> bool {
    // Length is measured in bytes; any non-ASCII byte fails the check below.
    if !(1..=64).contains(&id.len()) {
        return false;
    }
    id.bytes().all(|byte| {
        matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-')
    })
}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48#[derive(Default)]
49pub enum Language {
50    #[default]
51    Python,
52    TypeScript,
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56pub struct SessionLimits {
57    pub memory_mb: u32,
58    pub cpu_ms: u32,
59    pub wall_ms: u32,
60    pub max_stdout_bytes: u32,
61    pub max_external_calls: u32,
62    pub max_stack_depth: u16,
63}
64
65impl Default for SessionLimits {
66    fn default() -> Self {
67        Self {
68            memory_mb: 64,
69            cpu_ms: 2_000,
70            wall_ms: 5_000,
71            max_stdout_bytes: 65_536,
72            max_external_calls: 32,
73            max_stack_depth: 256,
74        }
75    }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct SessionMeta {
80    pub id: SessionId,
81    pub created_at: u64,
82    pub last_used_at: u64,
83    pub language: Language,
84    pub limits: SessionLimits,
85    pub snapshot_version: u32,
86}
87
88impl Default for SessionMeta {
89    fn default() -> Self {
90        Self {
91            id: SessionId::default(),
92            created_at: 0,
93            last_used_at: 0,
94            language: Language::Python,
95            limits: SessionLimits::default(),
96            snapshot_version: SNAPSHOT_VERSION,
97        }
98    }
99}
100
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
102#[serde(rename_all = "snake_case")]
103#[derive(Default)]
104pub enum SideEffect {
105    #[default]
106    None,
107    Read,
108    Write,
109    Network,
110    Database,
111    ExternalSystem,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct CapabilityLimits {
116    pub timeout_ms: u32,
117    pub max_request_bytes: u32,
118    pub max_response_bytes: u32,
119    pub concurrency: u32,
120    pub rate_per_min: Option<u32>,
121}
122
123impl Default for CapabilityLimits {
124    fn default() -> Self {
125        Self {
126            timeout_ms: 5_000,
127            max_request_bytes: 1_048_576,
128            max_response_bytes: 1_048_576,
129            concurrency: 16,
130            rate_per_min: None,
131        }
132    }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137#[derive(Default)]
138pub enum ApprovalPolicy {
139    #[default]
140    None,
141    Auto {
142        rule: String,
143    },
144    Manual,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct Capability {
149    pub name: String,
150    pub description: String,
151    pub input_schema: Value,
152    pub output_schema: Value,
153    pub side_effect: SideEffect,
154    pub limits: CapabilityLimits,
155    pub approval_policy: ApprovalPolicy,
156    pub idempotent: bool,
157}
158
159impl Capability {
160    pub fn new(
161        name: impl Into<String>,
162        description: impl Into<String>,
163        side_effect: SideEffect,
164    ) -> Self {
165        Self {
166            name: name.into(),
167            description: description.into(),
168            input_schema: json!({"type": "array"}),
169            output_schema: json!({}),
170            side_effect,
171            limits: CapabilityLimits::default(),
172            approval_policy: ApprovalPolicy::None,
173            idempotent: true,
174        }
175    }
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct RunRequest {
180    pub session_id: SessionId,
181    pub language: Language,
182    pub code: String,
183    pub inputs: Map<String, Value>,
184    pub timeout_ms: Option<u32>,
185    pub limits: Option<SessionLimits>,
186    pub return_snapshot: bool,
187    pub validate_only: bool,
188}
189
190impl RunRequest {
191    pub fn new(
192        session_id: impl Into<String>,
193        code: impl Into<String>,
194    ) -> Result<Self, ErrorObject> {
195        Ok(Self {
196            session_id: SessionId::new(session_id)?,
197            language: Language::Python,
198            code: code.into(),
199            inputs: Map::new(),
200            timeout_ms: None,
201            limits: None,
202            return_snapshot: false,
203            validate_only: false,
204        })
205    }
206}
207
208impl Default for RunRequest {
209    fn default() -> Self {
210        Self {
211            session_id: SessionId::default(),
212            language: Language::Python,
213            code: String::new(),
214            inputs: Map::new(),
215            timeout_ms: None,
216            limits: None,
217            return_snapshot: false,
218            validate_only: false,
219        }
220    }
221}
222
223#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case")]
225#[derive(Default)]
226pub enum RunStatus {
227    #[default]
228    Ok,
229    ValidationError,
230    RuntimeError,
231    Timeout,
232    Cancelled,
233    ResourceExhausted,
234    PermissionDenied,
235    WaitingForApproval,
236    Interrupted,
237}
238
239#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
240#[serde(rename_all = "snake_case")]
241#[derive(Default)]
242pub enum Severity {
243    Error,
244    Warning,
245    #[default]
246    Info,
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
250pub struct Span {
251    pub line: u32,
252    pub column: u32,
253    pub end_line: Option<u32>,
254    pub end_column: Option<u32>,
255}
256
257impl Default for Span {
258    fn default() -> Self {
259        Self {
260            line: 1,
261            column: 1,
262            end_line: None,
263            end_column: None,
264        }
265    }
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct Diagnostic {
270    pub severity: Severity,
271    pub code: String,
272    pub message: String,
273    pub hint: Option<String>,
274    pub span: Option<Span>,
275}
276
277impl Diagnostic {
278    pub fn error(error: &ErrorObject) -> Self {
279        Self {
280            severity: Severity::Error,
281            code: error.code.clone(),
282            message: error.message.clone(),
283            hint: error.hint.clone(),
284            span: error.span,
285        }
286    }
287}
288
289#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
290#[serde(rename_all = "snake_case")]
291#[derive(Default)]
292pub enum CallStatus {
293    #[default]
294    Ok,
295    Error,
296    Timeout,
297    Denied,
298    ApprovalPending,
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct ErrorObject {
303    pub code: String,
304    pub message: String,
305    pub hint: Option<String>,
306    pub span: Option<Span>,
307}
308
309impl ErrorObject {
310    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
311        Self {
312            code: code.into(),
313            message: message.into(),
314            hint: None,
315            span: None,
316        }
317    }
318
319    pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
320        self.hint = Some(hint.into());
321        self
322    }
323
324    pub fn with_span(mut self, span: Span) -> Self {
325        self.span = Some(span);
326        self
327    }
328}
329
330#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct ExternalCallRecord {
332    pub name: String,
333    pub side_effect: SideEffect,
334    pub duration_ms: u32,
335    pub status: CallStatus,
336    pub request_digest: String,
337    pub response_digest: Option<String>,
338    pub error: Option<ErrorObject>,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize, Default)]
342pub struct Metrics {
343    pub duration_ms: u32,
344    pub memory_peak_bytes: u64,
345    pub instructions: u64,
346    pub external_calls_count: u32,
347}
348
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct RunResult {
351    pub status: RunStatus,
352    pub result: Option<Value>,
353    pub stdout: String,
354    pub stderr: String,
355    pub diagnostics: Vec<Diagnostic>,
356    pub external_calls: Vec<ExternalCallRecord>,
357    pub snapshot_id: Option<String>,
358    pub metrics: Metrics,
359    pub error: Option<ErrorObject>,
360}
361
362impl RunResult {
363    pub fn ok(result: Option<Value>, stdout: String, metrics: Metrics) -> Self {
364        Self {
365            status: RunStatus::Ok,
366            result,
367            stdout,
368            stderr: String::new(),
369            diagnostics: Vec::new(),
370            external_calls: Vec::new(),
371            snapshot_id: None,
372            metrics,
373            error: None,
374        }
375    }
376
377    pub fn error(status: RunStatus, error: ErrorObject, stdout: String, metrics: Metrics) -> Self {
378        Self {
379            status,
380            result: None,
381            stdout,
382            stderr: String::new(),
383            diagnostics: vec![Diagnostic::error(&error)],
384            external_calls: Vec::new(),
385            snapshot_id: None,
386            metrics,
387            error: Some(error),
388        }
389    }
390}
391
392impl Default for RunResult {
393    fn default() -> Self {
394        Self::ok(None, String::new(), Metrics::default())
395    }
396}
397
398#[derive(Debug, Clone)]
399pub struct ToolCallContext {
400    pub name: String,
401    pub args: Vec<Value>,
402    pub kwargs: Map<String, Value>,
403}
404
/// Error returned by a tool handler.
///
/// Implements [`fmt::Display`] (rendered as `code: message`) and
/// [`std::error::Error`], so it can be printed and boxed as a standard error.
#[derive(Debug, Clone)]
pub struct ToolError {
    /// Machine-readable error code.
    pub code: String,
    /// Human-readable description.
    pub message: String,
}

impl ToolError {
    /// Creates a tool error from a code and a message.
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            code: code.into(),
            message: message.into(),
        }
    }
}

impl fmt::Display for ToolError {
    /// Renders the error as `code: message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.code, self.message)
    }
}

impl std::error::Error for ToolError {}
419
420pub type ToolResult = Result<Value, ToolError>;
421pub type RuntimeFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
422
423pub trait LanguageRuntime: fmt::Debug + Send + Sync {
424    fn language(&self) -> Language;
425
426    fn create_session(
427        &self,
428        session_id: SessionId,
429        limits: Option<SessionLimits>,
430    ) -> RuntimeFuture<'_, Result<(), ErrorObject>>;
431
432    fn run(&self, request: RunRequest) -> RuntimeFuture<'_, RunResult>;
433
434    fn destroy_session(
435        &self,
436        session_id: SessionId,
437    ) -> RuntimeFuture<'_, Result<bool, ErrorObject>>;
438
439    fn list_sessions(&self) -> RuntimeFuture<'_, Result<Vec<SessionId>, ErrorObject>>;
440
441    fn snapshot_session(
442        &self,
443        session_id: SessionId,
444    ) -> RuntimeFuture<'_, Result<Vec<u8>, ErrorObject>>;
445
446    fn restore_session(
447        &self,
448        snapshot: Vec<u8>,
449        session_id: Option<SessionId>,
450    ) -> RuntimeFuture<'_, Result<SessionId, ErrorObject>>;
451
452    fn can_restore_snapshot(&self, snapshot: &[u8]) -> bool;
453}
454pub type ToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send + 'static>>;
455pub type ToolFn = dyn Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static;
456
457#[derive(Clone)]
458pub struct RegisteredTool {
459    pub capability: Capability,
460    pub async_mode: bool,
461    handler: Arc<ToolFn>,
462}
463
464impl fmt::Debug for RegisteredTool {
465    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
466        f.debug_struct("RegisteredTool")
467            .field("capability", &self.capability)
468            .field("async_mode", &self.async_mode)
469            .finish_non_exhaustive()
470    }
471}
472
473impl RegisteredTool {
474    pub fn sync(
475        capability: Capability,
476        handler: impl Fn(ToolCallContext) -> ToolResult + Send + Sync + 'static,
477    ) -> Self {
478        let handler = Arc::new(move |ctx| {
479            let result = handler(ctx);
480            Box::pin(async move { result }) as ToolFuture
481        });
482        Self {
483            capability,
484            async_mode: false,
485            handler,
486        }
487    }
488
489    pub fn asynchronous(
490        capability: Capability,
491        handler: impl Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static,
492    ) -> Self {
493        Self {
494            capability,
495            async_mode: true,
496            handler: Arc::new(handler),
497        }
498    }
499
500    pub fn call(&self, ctx: ToolCallContext) -> ToolFuture {
501        (self.handler)(ctx)
502    }
503}
504
505#[derive(Debug, Clone, Default)]
506pub struct ToolRegistry {
507    tools: BTreeMap<String, RegisteredTool>,
508}
509
510impl ToolRegistry {
511    pub fn new() -> Self {
512        Self::default()
513    }
514
515    pub fn register(&mut self, tool: RegisteredTool) -> Result<(), ErrorObject> {
516        if !is_python_identifier(&tool.capability.name) {
517            return Err(ErrorObject::new(
518                "INVALID_TOOL_NAME",
519                format!(
520                    "Capability name '{}' is not a valid Python identifier.",
521                    tool.capability.name
522                ),
523            ));
524        }
525        self.tools.insert(tool.capability.name.clone(), tool);
526        Ok(())
527    }
528
529    pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
530        self.tools.get(name)
531    }
532
533    pub fn contains(&self, name: &str) -> bool {
534        self.tools.contains_key(name)
535    }
536
537    pub fn capabilities(&self) -> Vec<Capability> {
538        self.tools
539            .values()
540            .map(|tool| tool.capability.clone())
541            .collect()
542    }
543
544    pub fn names(&self) -> Vec<String> {
545        self.tools.keys().cloned().collect()
546    }
547}
548
/// Returns `true` when `name` is a non-empty ASCII identifier: the first
/// character is `_` or an ASCII letter and the rest are `_`, ASCII letters or
/// ASCII digits.
///
/// Non-ASCII characters are rejected, and reserved words are not treated
/// specially.
pub fn is_python_identifier(name: &str) -> bool {
    let mut rest = name.chars();
    match rest.next() {
        Some(head) if matches!(head, '_' | 'a'..='z' | 'A'..='Z') => {
            rest.all(|c| matches!(c, '_' | 'a'..='z' | 'A'..='Z' | '0'..='9'))
        }
        _ => false,
    }
}
557
558pub fn digest_json(value: &Value) -> String {
559    let bytes = serde_json::to_vec(value).unwrap_or_default();
560    let digest = Sha3_256::digest(bytes);
561    to_hex(&digest)
562}
563
564pub fn digest_bytes(bytes: &[u8]) -> String {
565    let digest = Sha3_256::digest(bytes);
566    to_hex(&digest)
567}
568
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes.iter().copied() {
        // High nibble first, then low nibble.
        let high = DIGITS[usize::from(byte >> 4)];
        let low = DIGITS[usize::from(byte & 0x0f)];
        hex.extend([high, low]);
    }
    hex
}
578
/// Returns the current system time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and
/// `u64::MAX` when the millisecond count does not fit in a `u64`.
pub fn now_unix_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// `SessionId::new` accepts the documented character set and rejects
    /// slashes and the empty string.
    #[test]
    fn validates_session_ids() {
        let cases = [("agent-123_ok", true), ("bad/slash", false), ("", false)];
        for &(input, accepted) in &cases {
            assert_eq!(
                SessionId::new(input).is_ok(),
                accepted,
                "session id {:?}",
                input
            );
        }
    }

    /// `is_python_identifier` rejects a leading digit and a hyphen.
    #[test]
    fn validates_python_identifiers() {
        let cases = [
            ("fetch_json", true),
            ("1_fetch", false),
            ("fetch-json", false),
        ];
        for &(input, accepted) in &cases {
            assert_eq!(
                is_python_identifier(input),
                accepted,
                "identifier {:?}",
                input
            );
        }
    }
}