Skip to main content

langshell_core/
lib.rs

1use std::{collections::BTreeMap, fmt, future::Future, pin::Pin, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use sha3::{Digest, Sha3_256};
6
7pub const SNAPSHOT_VERSION: u32 = 1;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct SessionId(pub String);
11
12impl SessionId {
13    pub fn new(id: impl Into<String>) -> Result<Self, ErrorObject> {
14        let id = id.into();
15        if is_valid_session_id(&id) {
16            Ok(Self(id))
17        } else {
18            Err(ErrorObject::new(
19                "INVALID_SESSION_ID",
20                "Session id must match ^[a-zA-Z0-9_\\-]{1,64}$.",
21            ))
22        }
23    }
24}
25
26impl Default for SessionId {
27    fn default() -> Self {
28        Self("default".to_owned())
29    }
30}
31
32impl fmt::Display for SessionId {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.write_str(&self.0)
35    }
36}
37
/// Returns `true` when `id` is 1 to 64 bytes long and consists solely of ASCII
/// letters, ASCII digits, `_`, or `-`.
fn is_valid_session_id(id: &str) -> bool {
    let allowed = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'-';
    (1..=64).contains(&id.len()) && id.bytes().all(allowed)
}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48#[derive(Default)]
49pub enum Language {
50    #[default]
51    Python,
52    TypeScript,
53}
54
55
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct SessionLimits {
58    pub memory_mb: u32,
59    pub cpu_ms: u32,
60    pub wall_ms: u32,
61    pub max_stdout_bytes: u32,
62    pub max_external_calls: u32,
63    pub max_stack_depth: u16,
64}
65
66impl Default for SessionLimits {
67    fn default() -> Self {
68        Self {
69            memory_mb: 64,
70            cpu_ms: 2_000,
71            wall_ms: 5_000,
72            max_stdout_bytes: 65_536,
73            max_external_calls: 32,
74            max_stack_depth: 256,
75        }
76    }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct SessionMeta {
81    pub id: SessionId,
82    pub created_at: u64,
83    pub last_used_at: u64,
84    pub language: Language,
85    pub limits: SessionLimits,
86    pub snapshot_version: u32,
87}
88
89impl Default for SessionMeta {
90    fn default() -> Self {
91        Self {
92            id: SessionId::default(),
93            created_at: 0,
94            last_used_at: 0,
95            language: Language::Python,
96            limits: SessionLimits::default(),
97            snapshot_version: SNAPSHOT_VERSION,
98        }
99    }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(rename_all = "snake_case")]
104#[derive(Default)]
105pub enum SideEffect {
106    #[default]
107    None,
108    Read,
109    Write,
110    Network,
111    Database,
112    ExternalSystem,
113}
114
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CapabilityLimits {
118    pub timeout_ms: u32,
119    pub max_request_bytes: u32,
120    pub max_response_bytes: u32,
121    pub concurrency: u32,
122    pub rate_per_min: Option<u32>,
123}
124
125impl Default for CapabilityLimits {
126    fn default() -> Self {
127        Self {
128            timeout_ms: 5_000,
129            max_request_bytes: 1_048_576,
130            max_response_bytes: 1_048_576,
131            concurrency: 16,
132            rate_per_min: None,
133        }
134    }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139#[derive(Default)]
140pub enum ApprovalPolicy {
141    #[default]
142    None,
143    Auto { rule: String },
144    Manual,
145}
146
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Capability {
150    pub name: String,
151    pub description: String,
152    pub input_schema: Value,
153    pub output_schema: Value,
154    pub side_effect: SideEffect,
155    pub limits: CapabilityLimits,
156    pub approval_policy: ApprovalPolicy,
157    pub idempotent: bool,
158}
159
160impl Capability {
161    pub fn new(
162        name: impl Into<String>,
163        description: impl Into<String>,
164        side_effect: SideEffect,
165    ) -> Self {
166        Self {
167            name: name.into(),
168            description: description.into(),
169            input_schema: json!({"type": "array"}),
170            output_schema: json!({}),
171            side_effect,
172            limits: CapabilityLimits::default(),
173            approval_policy: ApprovalPolicy::None,
174            idempotent: true,
175        }
176    }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct RunRequest {
181    pub session_id: SessionId,
182    pub language: Language,
183    pub code: String,
184    pub inputs: Map<String, Value>,
185    pub timeout_ms: Option<u32>,
186    pub limits: Option<SessionLimits>,
187    pub return_snapshot: bool,
188    pub validate_only: bool,
189}
190
191impl RunRequest {
192    pub fn new(
193        session_id: impl Into<String>,
194        code: impl Into<String>,
195    ) -> Result<Self, ErrorObject> {
196        Ok(Self {
197            session_id: SessionId::new(session_id)?,
198            language: Language::Python,
199            code: code.into(),
200            inputs: Map::new(),
201            timeout_ms: None,
202            limits: None,
203            return_snapshot: false,
204            validate_only: false,
205        })
206    }
207}
208
209impl Default for RunRequest {
210    fn default() -> Self {
211        Self {
212            session_id: SessionId::default(),
213            language: Language::Python,
214            code: String::new(),
215            inputs: Map::new(),
216            timeout_ms: None,
217            limits: None,
218            return_snapshot: false,
219            validate_only: false,
220        }
221    }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "snake_case")]
226#[derive(Default)]
227pub enum RunStatus {
228    #[default]
229    Ok,
230    ValidationError,
231    RuntimeError,
232    Timeout,
233    Cancelled,
234    ResourceExhausted,
235    PermissionDenied,
236    WaitingForApproval,
237    Interrupted,
238}
239
240
241#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
242#[serde(rename_all = "snake_case")]
243#[derive(Default)]
244pub enum Severity {
245    Error,
246    Warning,
247    #[default]
248    Info,
249}
250
251
252#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
253pub struct Span {
254    pub line: u32,
255    pub column: u32,
256    pub end_line: Option<u32>,
257    pub end_column: Option<u32>,
258}
259
260impl Default for Span {
261    fn default() -> Self {
262        Self {
263            line: 1,
264            column: 1,
265            end_line: None,
266            end_column: None,
267        }
268    }
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct Diagnostic {
273    pub severity: Severity,
274    pub code: String,
275    pub message: String,
276    pub hint: Option<String>,
277    pub span: Option<Span>,
278}
279
280impl Diagnostic {
281    pub fn error(error: &ErrorObject) -> Self {
282        Self {
283            severity: Severity::Error,
284            code: error.code.clone(),
285            message: error.message.clone(),
286            hint: error.hint.clone(),
287            span: error.span,
288        }
289    }
290}
291
292#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
293#[serde(rename_all = "snake_case")]
294#[derive(Default)]
295pub enum CallStatus {
296    #[default]
297    Ok,
298    Error,
299    Timeout,
300    Denied,
301    ApprovalPending,
302}
303
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub struct ErrorObject {
307    pub code: String,
308    pub message: String,
309    pub hint: Option<String>,
310    pub span: Option<Span>,
311}
312
313impl ErrorObject {
314    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
315        Self {
316            code: code.into(),
317            message: message.into(),
318            hint: None,
319            span: None,
320        }
321    }
322
323    pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
324        self.hint = Some(hint.into());
325        self
326    }
327
328    pub fn with_span(mut self, span: Span) -> Self {
329        self.span = Some(span);
330        self
331    }
332}
333
334#[derive(Debug, Clone, Serialize, Deserialize)]
335pub struct ExternalCallRecord {
336    pub name: String,
337    pub side_effect: SideEffect,
338    pub duration_ms: u32,
339    pub status: CallStatus,
340    pub request_digest: String,
341    pub response_digest: Option<String>,
342    pub error: Option<ErrorObject>,
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
346#[derive(Default)]
347pub struct Metrics {
348    pub duration_ms: u32,
349    pub memory_peak_bytes: u64,
350    pub instructions: u64,
351    pub external_calls_count: u32,
352}
353
354
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct RunResult {
357    pub status: RunStatus,
358    pub result: Option<Value>,
359    pub stdout: String,
360    pub stderr: String,
361    pub diagnostics: Vec<Diagnostic>,
362    pub external_calls: Vec<ExternalCallRecord>,
363    pub snapshot_id: Option<String>,
364    pub metrics: Metrics,
365    pub error: Option<ErrorObject>,
366}
367
368impl RunResult {
369    pub fn ok(result: Option<Value>, stdout: String, metrics: Metrics) -> Self {
370        Self {
371            status: RunStatus::Ok,
372            result,
373            stdout,
374            stderr: String::new(),
375            diagnostics: Vec::new(),
376            external_calls: Vec::new(),
377            snapshot_id: None,
378            metrics,
379            error: None,
380        }
381    }
382
383    pub fn error(status: RunStatus, error: ErrorObject, stdout: String, metrics: Metrics) -> Self {
384        Self {
385            status,
386            result: None,
387            stdout,
388            stderr: String::new(),
389            diagnostics: vec![Diagnostic::error(&error)],
390            external_calls: Vec::new(),
391            snapshot_id: None,
392            metrics,
393            error: Some(error),
394        }
395    }
396}
397
398impl Default for RunResult {
399    fn default() -> Self {
400        Self::ok(None, String::new(), Metrics::default())
401    }
402}
403
404#[derive(Debug, Clone)]
405pub struct ToolCallContext {
406    pub name: String,
407    pub args: Vec<Value>,
408    pub kwargs: Map<String, Value>,
409}
410
/// Error returned by a tool handler (the `Err` side of `ToolResult`).
///
/// Implements [`fmt::Display`] (`CODE: message`) and [`std::error::Error`] so
/// handlers can interoperate with `?` and `Box<dyn Error>`.
#[derive(Debug, Clone)]
pub struct ToolError {
    /// Machine-readable error code.
    pub code: String,
    /// Human-readable description.
    pub message: String,
}

impl ToolError {
    /// Creates a tool error from a code and a message.
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        Self {
            code: code.into(),
            message: message.into(),
        }
    }
}

impl fmt::Display for ToolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.code, self.message)
    }
}

impl std::error::Error for ToolError {}
425
426pub type ToolResult = Result<Value, ToolError>;
427pub type ToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send + 'static>>;
428pub type ToolFn = dyn Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static;
429
430#[derive(Clone)]
431pub struct RegisteredTool {
432    pub capability: Capability,
433    pub async_mode: bool,
434    handler: Arc<ToolFn>,
435}
436
437impl fmt::Debug for RegisteredTool {
438    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439        f.debug_struct("RegisteredTool")
440            .field("capability", &self.capability)
441            .field("async_mode", &self.async_mode)
442            .finish_non_exhaustive()
443    }
444}
445
446impl RegisteredTool {
447    pub fn sync(
448        capability: Capability,
449        handler: impl Fn(ToolCallContext) -> ToolResult + Send + Sync + 'static,
450    ) -> Self {
451        let handler = Arc::new(move |ctx| {
452            let result = handler(ctx);
453            Box::pin(async move { result }) as ToolFuture
454        });
455        Self {
456            capability,
457            async_mode: false,
458            handler,
459        }
460    }
461
462    pub fn asynchronous(
463        capability: Capability,
464        handler: impl Fn(ToolCallContext) -> ToolFuture + Send + Sync + 'static,
465    ) -> Self {
466        Self {
467            capability,
468            async_mode: true,
469            handler: Arc::new(handler),
470        }
471    }
472
473    pub fn call(&self, ctx: ToolCallContext) -> ToolFuture {
474        (self.handler)(ctx)
475    }
476}
477
478#[derive(Debug, Clone, Default)]
479pub struct ToolRegistry {
480    tools: BTreeMap<String, RegisteredTool>,
481}
482
483impl ToolRegistry {
484    pub fn new() -> Self {
485        Self::default()
486    }
487
488    pub fn register(&mut self, tool: RegisteredTool) -> Result<(), ErrorObject> {
489        if !is_python_identifier(&tool.capability.name) {
490            return Err(ErrorObject::new(
491                "INVALID_TOOL_NAME",
492                format!(
493                    "Capability name '{}' is not a valid Python identifier.",
494                    tool.capability.name
495                ),
496            ));
497        }
498        self.tools.insert(tool.capability.name.clone(), tool);
499        Ok(())
500    }
501
502    pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
503        self.tools.get(name)
504    }
505
506    pub fn contains(&self, name: &str) -> bool {
507        self.tools.contains_key(name)
508    }
509
510    pub fn capabilities(&self) -> Vec<Capability> {
511        self.tools
512            .values()
513            .map(|tool| tool.capability.clone())
514            .collect()
515    }
516
517    pub fn names(&self) -> Vec<String> {
518        self.tools.keys().cloned().collect()
519    }
520}
521
/// Returns `true` when `name` can be used as a Python name for a tool.
///
/// This is stricter than Python in one way: only ASCII is accepted (a letter or
/// `_`, followed by letters, digits, or `_`), although Python also allows Unicode
/// identifiers. Python's hard keywords (`class`, `None`, `lambda`, ...) are
/// rejected because they match the identifier grammar but cannot be used as
/// names; `class(...)` is a syntax error. Soft keywords such as `match` and
/// `type` remain allowed.
pub fn is_python_identifier(name: &str) -> bool {
    // Hard keywords from the Python 3 language reference.
    const KEYWORDS: &[&str] = &[
        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from",
        "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
        "raise", "return", "try", "while", "with", "yield",
    ];

    let mut chars = name.chars();
    let Some(first) = chars.next() else {
        return false;
    };
    (first == '_' || first.is_ascii_alphabetic())
        && chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
        && !KEYWORDS.contains(&name)
}
530
531pub fn digest_json(value: &Value) -> String {
532    let bytes = serde_json::to_vec(value).unwrap_or_default();
533    let digest = Sha3_256::digest(bytes);
534    to_hex(&digest)
535}
536
537pub fn digest_bytes(bytes: &[u8]) -> String {
538    let digest = Sha3_256::digest(bytes);
539    to_hex(&digest)
540}
541
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&b| [DIGITS[usize::from(b >> 4)], DIGITS[usize::from(b & 0x0f)]])
        .map(char::from)
        .collect()
}
551
/// Milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads before the epoch, and saturates at
/// `u64::MAX` if the value does not fit.
pub fn now_unix_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validates_session_ids() {
        assert!(SessionId::new("agent-123_ok").is_ok());
        assert!(SessionId::new("bad/slash").is_err());
        assert!(SessionId::new("").is_err());
    }

    #[test]
    fn session_id_length_boundary() {
        assert!(SessionId::new("a".repeat(64)).is_ok());
        assert!(SessionId::new("a".repeat(65)).is_err());
    }

    #[test]
    fn session_id_rejects_non_ascii_and_whitespace() {
        assert!(SessionId::new("café").is_err());
        assert!(SessionId::new("has space").is_err());
    }

    #[test]
    fn validates_python_identifiers() {
        assert!(is_python_identifier("fetch_json"));
        assert!(is_python_identifier("_private"));
        assert!(is_python_identifier("x1"));
        assert!(!is_python_identifier(""));
        assert!(!is_python_identifier("1_fetch"));
        assert!(!is_python_identifier("fetch-json"));
    }

    #[test]
    fn hex_is_lowercase_and_zero_padded() {
        assert_eq!(to_hex(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
        assert_eq!(to_hex(&[]), "");
    }

    #[test]
    fn registry_rejects_invalid_tool_names() {
        let mut registry = ToolRegistry::new();
        let bad = RegisteredTool::sync(
            Capability::new("bad-name", "invalid", SideEffect::None),
            |_| Ok(Value::Null),
        );
        assert!(registry.register(bad).is_err());
        assert!(registry.names().is_empty());

        let good = RegisteredTool::sync(
            Capability::new("fetch", "valid", SideEffect::Read),
            |_| Ok(Value::Null),
        );
        assert!(registry.register(good).is_ok());
        assert!(registry.contains("fetch"));
        assert_eq!(registry.names(), vec!["fetch".to_owned()]);
    }
}