Skip to main content

neuron_types/
traits.rs

//! Core traits: Provider, Tool, ToolDyn, ContextStrategy, ObservabilityHook, DurableContext,
//! PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use crate::error::{ContextError, DurableError, HookError, ProviderError, ToolError};
10use crate::stream::StreamHandle;
11use crate::types::{
12    CompletionRequest, CompletionResponse, ContentItem, Message, ToolContext, ToolDefinition,
13    ToolOutput,
14};
15use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
16
17/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
18///
19/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
20/// Not object-safe by design; use generics `<P: Provider>` to compose.
21///
22/// # Example
23///
24/// ```ignore
25/// struct MyProvider;
26///
27/// impl Provider for MyProvider {
28///     fn complete(&self, request: CompletionRequest)
29///         -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
30///     {
31///         async { todo!() }
32///     }
33///
34///     fn complete_stream(&self, request: CompletionRequest)
35///         -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
36///     {
37///         async { todo!() }
38///     }
39/// }
40/// ```
41pub trait Provider: WasmCompatSend + WasmCompatSync {
42    /// Send a completion request and get a full response.
43    fn complete(
44        &self,
45        request: CompletionRequest,
46    ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
47
48    /// Send a completion request and get a stream of events.
49    fn complete_stream(
50        &self,
51        request: CompletionRequest,
52    ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
53}
54
55/// Strongly-typed tool trait. Implement this for your tools.
56///
57/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
58/// so you work with concrete Rust types.
59///
60/// # Example
61///
62/// ```ignore
63/// use neuron_types::*;
64/// use serde::Deserialize;
65///
66/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
67/// struct MyArgs { query: String }
68///
69/// struct MyTool;
70/// impl Tool for MyTool {
71///     const NAME: &'static str = "my_tool";
72///     type Args = MyArgs;
73///     type Output = String;
74///     type Error = std::io::Error;
75///
76///     fn definition(&self) -> ToolDefinition { todo!() }
77///     fn call(&self, args: MyArgs, ctx: &ToolContext)
78///         -> impl Future<Output = Result<String, std::io::Error>> + Send
79///     { async { Ok(args.query) } }
80/// }
81/// ```
82pub trait Tool: WasmCompatSend + WasmCompatSync {
83    /// The unique name of this tool.
84    const NAME: &'static str;
85    /// The deserialized input type.
86    type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
87    /// The serializable output type.
88    type Output: Serialize;
89    /// The tool-specific error type.
90    type Error: std::error::Error + WasmCompatSend + 'static;
91
92    /// Returns the tool definition (name, description, schema).
93    fn definition(&self) -> ToolDefinition;
94
95    /// Execute the tool with typed arguments.
96    fn call(
97        &self,
98        args: Self::Args,
99        ctx: &ToolContext,
100    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
101}
102
103/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
104///
105/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
106/// while preserving type safety at the implementation level.
107pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
108    /// The tool's unique name.
109    fn name(&self) -> &str;
110    /// The tool definition (name, description, input schema).
111    fn definition(&self) -> ToolDefinition;
112    /// Execute the tool with a JSON value input, returning a generic output.
113    fn call_dyn<'a>(
114        &'a self,
115        input: serde_json::Value,
116        ctx: &'a ToolContext,
117    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
118}
119
120/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
121///
122/// Handles:
123/// - Deserializing `serde_json::Value` into `T::Args`
124/// - Calling `T::call(args, ctx)`
125/// - Serializing `T::Output` into `ToolOutput`
126/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
127impl<T: Tool> ToolDyn for T {
128    fn name(&self) -> &str {
129        T::NAME
130    }
131
132    fn definition(&self) -> ToolDefinition {
133        Tool::definition(self)
134    }
135
136    fn call_dyn<'a>(
137        &'a self,
138        input: serde_json::Value,
139        ctx: &'a ToolContext,
140    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
141        Box::pin(async move {
142            let args: T::Args = serde_json::from_value(input)
143                .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
144
145            let output = self.call(args, ctx).await.map_err(|e| {
146                ToolError::ExecutionFailed(e.to_string().into())
147            })?;
148
149            let structured = serde_json::to_value(&output)
150                .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
151
152            let text = match &structured {
153                serde_json::Value::String(s) => s.clone(),
154                other => other.to_string(),
155            };
156
157            Ok(ToolOutput {
158                content: vec![ContentItem::Text(text)],
159                structured_content: Some(structured),
160                is_error: false,
161            })
162        })
163    }
164}
165
166// --- Context Strategy ---
167
168/// Strategy for compacting conversation context when it exceeds token limits.
169///
170/// Implementations decide when to compact and how to reduce the message list.
171///
172/// # Example
173///
174/// ```ignore
175/// struct KeepLastN { n: usize }
176/// impl ContextStrategy for KeepLastN {
177///     fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
178///         token_count > 100_000
179///     }
180///     fn compact(&self, messages: Vec<Message>)
181///         -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
182///     {
183///         async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
184///     }
185///     fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
186/// }
187/// ```
188pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
189    /// Whether compaction should be triggered given the current messages and token count.
190    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
191
192    /// Compact the message list to reduce token usage.
193    fn compact(
194        &self,
195        messages: Vec<Message>,
196    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
197
198    /// Estimate the token count for a list of messages.
199    fn token_estimate(&self, messages: &[Message]) -> usize;
200}
201
202// --- Observability Hooks ---
203
204/// Events fired during the agentic loop for observability.
205#[derive(Debug)]
206pub enum HookEvent<'a> {
207    /// Start of a loop iteration.
208    LoopIteration {
209        /// The current turn number (0-indexed).
210        turn: usize,
211    },
212    /// Before calling the LLM provider.
213    PreLlmCall {
214        /// The request about to be sent.
215        request: &'a CompletionRequest,
216    },
217    /// After receiving the LLM response.
218    PostLlmCall {
219        /// The response received.
220        response: &'a CompletionResponse,
221    },
222    /// Before executing a tool.
223    PreToolExecution {
224        /// Name of the tool.
225        tool_name: &'a str,
226        /// Input arguments.
227        input: &'a serde_json::Value,
228    },
229    /// After executing a tool.
230    PostToolExecution {
231        /// Name of the tool.
232        tool_name: &'a str,
233        /// The tool's output.
234        output: &'a ToolOutput,
235    },
236    /// Context was compacted.
237    ContextCompaction {
238        /// Token count before compaction.
239        old_tokens: usize,
240        /// Token count after compaction.
241        new_tokens: usize,
242    },
243    /// A session started.
244    SessionStart {
245        /// The session identifier.
246        session_id: &'a str,
247    },
248    /// A session ended.
249    SessionEnd {
250        /// The session identifier.
251        session_id: &'a str,
252    },
253}
254
/// Action to take after processing a hook event.
///
/// Derives `Clone`, `PartialEq` and `Eq` so that actions can be duplicated and
/// compared directly (e.g. `assert_eq!(action, HookAction::Continue)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
271
272/// Hook for observability (logging, metrics, telemetry).
273///
274/// Does NOT control execution flow beyond Continue/Skip/Terminate.
275/// For durable execution wrapping, use [`DurableContext`] instead.
276///
277/// # Example
278///
279/// ```ignore
280/// struct LogHook;
281/// impl ObservabilityHook for LogHook {
282///     fn on_event(&self, event: HookEvent<'_>)
283///         -> impl Future<Output = Result<HookAction, HookError>> + Send
284///     {
285///         async { println!("{event:?}"); Ok(HookAction::Continue) }
286///     }
287/// }
288/// ```
289pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
290    /// Called for each event in the agentic loop.
291    fn on_event(
292        &self,
293        event: HookEvent<'_>,
294    ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
295}
296
297// --- Durable Context ---
298
299/// Options for executing an activity in a durable context.
300#[derive(Debug, Clone)]
301pub struct ActivityOptions {
302    /// Maximum time for the activity to complete.
303    pub start_to_close_timeout: Duration,
304    /// Heartbeat interval for long-running activities.
305    pub heartbeat_timeout: Option<Duration>,
306    /// Retry policy for failed activities.
307    pub retry_policy: Option<RetryPolicy>,
308}
309
/// Describes how a durable activity is retried after failure.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Exponential-backoff multiplier.
    pub backoff_coefficient: f64,
    /// Upper limit on the number of retry attempts.
    pub maximum_attempts: u32,
    /// Upper limit on the delay between retries.
    pub maximum_interval: Duration,
    /// Error types for which no retry should be made.
    pub non_retryable_errors: Vec<String>,
}
324
325/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
326///
327/// When present, the agentic loop calls through this instead of directly
328/// calling provider/tools, enabling journaling, replay, and crash recovery.
329pub trait DurableContext: WasmCompatSend + WasmCompatSync {
330    /// Execute an LLM call as a durable activity.
331    fn execute_llm_call(
332        &self,
333        request: CompletionRequest,
334        options: ActivityOptions,
335    ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
336
337    /// Execute a tool call as a durable activity.
338    fn execute_tool(
339        &self,
340        tool_name: &str,
341        input: serde_json::Value,
342        ctx: &ToolContext,
343        options: ActivityOptions,
344    ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
345
346    /// Wait for an external signal with a timeout.
347    fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
348        &self,
349        signal_name: &str,
350        timeout: Duration,
351    ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
352
353    /// Whether the workflow should continue-as-new to avoid history bloat.
354    fn should_continue_as_new(&self) -> bool;
355
356    /// Continue the workflow as a new execution with the given state.
357    fn continue_as_new(
358        &self,
359        state: serde_json::Value,
360    ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
361
362    /// Sleep for a duration (durable — survives replay).
363    fn sleep(
364        &self,
365        duration: Duration,
366    ) -> impl Future<Output = ()> + WasmCompatSend;
367
368    /// Current time (deterministic during replay).
369    fn now(&self) -> chrono::DateTime<chrono::Utc>;
370}
371
372// --- Permission Policy ---
373
/// Decision from a permission check.
///
/// Derives `PartialEq` and `Eq` so that decisions can be compared directly
/// (e.g. `policy.check(..) == PermissionDecision::Allow`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call with a reason.
    Deny(String),
    /// Ask the user for confirmation.
    Ask(String),
}
384
385/// Policy for checking tool call permissions.
386///
387/// # Example
388///
389/// ```ignore
390/// struct AllowAll;
391/// impl PermissionPolicy for AllowAll {
392///     fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
393///         PermissionDecision::Allow
394///     }
395/// }
396/// ```
397pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
398    /// Check whether a tool call is permitted.
399    fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
400}