// neuron_types/traits.rs

//! Core traits: Provider, EmbeddingProvider, Tool, ToolDyn, ContextStrategy, ObservabilityHook,
//! DurableContext, PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use crate::error::{ContextError, DurableError, EmbeddingError, HookError, ProviderError, ToolError};
10use crate::stream::StreamHandle;
11use crate::types::{
12    CompletionRequest, CompletionResponse, ContentItem, EmbeddingRequest, EmbeddingResponse,
13    Message, ToolContext, ToolDefinition, ToolOutput,
14};
15use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
16
17/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
18///
19/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
20/// Not object-safe by design; use generics `<P: Provider>` to compose.
21///
22/// # Example
23///
24/// ```no_run
25/// use std::future::Future;
26/// use neuron_types::*;
27///
28/// struct MyProvider;
29///
30/// impl Provider for MyProvider {
31///     fn complete(&self, request: CompletionRequest)
32///         -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
33///     {
34///         async { todo!() }
35///     }
36///
37///     fn complete_stream(&self, request: CompletionRequest)
38///         -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
39///     {
40///         async { todo!() }
41///     }
42/// }
43/// ```
44pub trait Provider: WasmCompatSend + WasmCompatSync {
45    /// Send a completion request and get a full response.
46    fn complete(
47        &self,
48        request: CompletionRequest,
49    ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
50
51    /// Send a completion request and get a stream of events.
52    fn complete_stream(
53        &self,
54        request: CompletionRequest,
55    ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
56}
57
58/// Embedding provider trait. Implement this for providers that support text embeddings.
59///
60/// Kept separate from [`Provider`] because not all embedding models support chat completion
61/// and not all chat providers support embeddings. Implement both traits on a struct when a
62/// provider supports both capabilities.
63///
64/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
65/// Not object-safe by design; use generics `<E: EmbeddingProvider>` to compose.
66///
67/// # Example
68///
69/// ```no_run
70/// use std::future::Future;
71/// use neuron_types::*;
72///
73/// struct MyEmbeddingProvider;
74///
75/// impl EmbeddingProvider for MyEmbeddingProvider {
76///     fn embed(&self, request: EmbeddingRequest)
77///         -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + Send
78///     {
79///         async { todo!() }
80///     }
81/// }
82/// ```
83pub trait EmbeddingProvider: WasmCompatSend + WasmCompatSync {
84    /// Send an embedding request and get vectors back.
85    ///
86    /// Multiple input strings are batched into a single request. The returned
87    /// `embeddings` vec is in the same order as `request.input`.
88    fn embed(
89        &self,
90        request: EmbeddingRequest,
91    ) -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend;
92}
93
94/// Strongly-typed tool trait. Implement this for your tools.
95///
96/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
97/// so you work with concrete Rust types.
98///
99/// # Example
100///
101/// ```no_run
102/// use neuron_types::*;
103/// use serde::Deserialize;
104///
105/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
106/// struct MyArgs { query: String }
107///
108/// struct MyTool;
109/// impl Tool for MyTool {
110///     const NAME: &'static str = "my_tool";
111///     type Args = MyArgs;
112///     type Output = String;
113///     type Error = std::io::Error;
114///
115///     fn definition(&self) -> ToolDefinition { todo!() }
116///     fn call(&self, args: MyArgs, ctx: &ToolContext)
117///         -> impl Future<Output = Result<String, std::io::Error>> + Send
118///     { async { Ok(args.query) } }
119/// }
120/// ```
121pub trait Tool: WasmCompatSend + WasmCompatSync {
122    /// The unique name of this tool.
123    const NAME: &'static str;
124    /// The deserialized input type.
125    type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
126    /// The serializable output type.
127    type Output: Serialize;
128    /// The tool-specific error type.
129    type Error: std::error::Error + WasmCompatSend + 'static;
130
131    /// Returns the tool definition (name, description, schema).
132    fn definition(&self) -> ToolDefinition;
133
134    /// Execute the tool with typed arguments.
135    fn call(
136        &self,
137        args: Self::Args,
138        ctx: &ToolContext,
139    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
140}
141
142/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
143///
144/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
145/// while preserving type safety at the implementation level.
146pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
147    /// The tool's unique name.
148    fn name(&self) -> &str;
149    /// The tool definition (name, description, input schema).
150    fn definition(&self) -> ToolDefinition;
151    /// Execute the tool with a JSON value input, returning a generic output.
152    fn call_dyn<'a>(
153        &'a self,
154        input: serde_json::Value,
155        ctx: &'a ToolContext,
156    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
157}
158
159/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
160///
161/// Handles:
162/// - Deserializing `serde_json::Value` into `T::Args`
163/// - Calling `T::call(args, ctx)`
164/// - Serializing `T::Output` into `ToolOutput`
165/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
166impl<T: Tool> ToolDyn for T {
167    fn name(&self) -> &str {
168        T::NAME
169    }
170
171    fn definition(&self) -> ToolDefinition {
172        Tool::definition(self)
173    }
174
175    fn call_dyn<'a>(
176        &'a self,
177        input: serde_json::Value,
178        ctx: &'a ToolContext,
179    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
180        Box::pin(async move {
181            let args: T::Args = serde_json::from_value(input)
182                .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
183
184            let output = self.call(args, ctx).await.map_err(|e| {
185                ToolError::ExecutionFailed(e.to_string().into())
186            })?;
187
188            let structured = serde_json::to_value(&output)
189                .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
190
191            let text = match &structured {
192                serde_json::Value::String(s) => s.clone(),
193                other => other.to_string(),
194            };
195
196            Ok(ToolOutput {
197                content: vec![ContentItem::Text(text)],
198                structured_content: Some(structured),
199                is_error: false,
200            })
201        })
202    }
203}
204
205// --- Context Strategy ---
206
207/// Strategy for compacting conversation context when it exceeds token limits.
208///
209/// Implementations decide when to compact and how to reduce the message list.
210///
211/// # Example
212///
213/// ```no_run
214/// use std::future::Future;
215/// use neuron_types::*;
216///
217/// struct KeepLastN { n: usize }
218/// impl ContextStrategy for KeepLastN {
219///     fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
220///         token_count > 100_000
221///     }
222///     fn compact(&self, messages: Vec<Message>)
223///         -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
224///     {
225///         async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
226///     }
227///     fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
228/// }
229/// ```
230pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
231    /// Whether compaction should be triggered given the current messages and token count.
232    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
233
234    /// Compact the message list to reduce token usage.
235    fn compact(
236        &self,
237        messages: Vec<Message>,
238    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
239
240    /// Estimate the token count for a list of messages.
241    fn token_estimate(&self, messages: &[Message]) -> usize;
242}
243
244// --- Observability Hooks ---
245
246/// Events fired during the agentic loop for observability.
247#[derive(Debug)]
248pub enum HookEvent<'a> {
249    /// Start of a loop iteration.
250    LoopIteration {
251        /// The current turn number (0-indexed).
252        turn: usize,
253    },
254    /// Before calling the LLM provider.
255    PreLlmCall {
256        /// The request about to be sent.
257        request: &'a CompletionRequest,
258    },
259    /// After receiving the LLM response.
260    PostLlmCall {
261        /// The response received.
262        response: &'a CompletionResponse,
263    },
264    /// Before executing a tool.
265    PreToolExecution {
266        /// Name of the tool.
267        tool_name: &'a str,
268        /// Input arguments.
269        input: &'a serde_json::Value,
270    },
271    /// After executing a tool.
272    PostToolExecution {
273        /// Name of the tool.
274        tool_name: &'a str,
275        /// The tool's output.
276        output: &'a ToolOutput,
277    },
278    /// Context was compacted.
279    ContextCompaction {
280        /// Token count before compaction.
281        old_tokens: usize,
282        /// Token count after compaction.
283        new_tokens: usize,
284    },
285    /// A session started.
286    SessionStart {
287        /// The session identifier.
288        session_id: &'a str,
289    },
290    /// A session ended.
291    SessionEnd {
292        /// The session identifier.
293        session_id: &'a str,
294    },
295}
296
/// Action to take after processing a hook event.
///
/// Derives `Clone`, `PartialEq` and `Eq` so that callers and tests can
/// duplicate and compare the action returned by a hook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
313
314/// Hook for observability (logging, metrics, telemetry).
315///
316/// Does NOT control execution flow beyond Continue/Skip/Terminate.
317/// For durable execution wrapping, use [`DurableContext`] instead.
318///
319/// # Example
320///
321/// ```no_run
322/// use std::future::Future;
323/// use neuron_types::*;
324///
325/// struct LogHook;
326/// impl ObservabilityHook for LogHook {
327///     fn on_event(&self, event: HookEvent<'_>)
328///         -> impl Future<Output = Result<HookAction, HookError>> + Send
329///     {
330///         async move { println!("{event:?}"); Ok(HookAction::Continue) }
331///     }
332/// }
333/// ```
334pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
335    /// Called for each event in the agentic loop.
336    fn on_event(
337        &self,
338        event: HookEvent<'_>,
339    ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
340}
341
342// --- Durable Context ---
343
/// Options for executing an activity in a durable context.
///
/// Comparable with `==` so configurations can be checked in tests and
/// deduplicated by callers.
#[derive(Debug, Clone, PartialEq)]
pub struct ActivityOptions {
    /// Maximum time for the activity to complete.
    pub start_to_close_timeout: Duration,
    /// Heartbeat interval for long-running activities.
    pub heartbeat_timeout: Option<Duration>,
    /// Retry policy for failed activities.
    pub retry_policy: Option<RetryPolicy>,
}

/// Retry policy for durable activities.
///
/// Only `PartialEq` (not `Eq`) is derived because `backoff_coefficient` is an
/// `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Initial delay before first retry.
    pub initial_interval: Duration,
    /// Multiplier for exponential backoff.
    pub backoff_coefficient: f64,
    /// Maximum number of retry attempts.
    pub maximum_attempts: u32,
    /// Maximum delay between retries.
    pub maximum_interval: Duration,
    /// Error types that should not be retried.
    pub non_retryable_errors: Vec<String>,
}
369
370/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
371///
372/// When present, the agentic loop calls through this instead of directly
373/// calling provider/tools, enabling journaling, replay, and crash recovery.
374pub trait DurableContext: WasmCompatSend + WasmCompatSync {
375    /// Execute an LLM call as a durable activity.
376    fn execute_llm_call(
377        &self,
378        request: CompletionRequest,
379        options: ActivityOptions,
380    ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
381
382    /// Execute a tool call as a durable activity.
383    fn execute_tool(
384        &self,
385        tool_name: &str,
386        input: serde_json::Value,
387        ctx: &ToolContext,
388        options: ActivityOptions,
389    ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
390
391    /// Wait for an external signal with a timeout.
392    fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
393        &self,
394        signal_name: &str,
395        timeout: Duration,
396    ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
397
398    /// Whether the workflow should continue-as-new to avoid history bloat.
399    fn should_continue_as_new(&self) -> bool;
400
401    /// Continue the workflow as a new execution with the given state.
402    fn continue_as_new(
403        &self,
404        state: serde_json::Value,
405    ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
406
407    /// Sleep for a duration (durable — survives replay).
408    fn sleep(
409        &self,
410        duration: Duration,
411    ) -> impl Future<Output = ()> + WasmCompatSend;
412
413    /// Current time (deterministic during replay).
414    fn now(&self) -> chrono::DateTime<chrono::Utc>;
415}
416
417// --- Permission Policy ---
418
/// Decision from a permission check.
///
/// Derives `PartialEq` and `Eq` so decisions can be compared directly, e.g.
/// `decision == PermissionDecision::Allow`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call with a reason.
    Deny(String),
    /// Ask the user for confirmation.
    Ask(String),
}
429
430/// Policy for checking tool call permissions.
431///
432/// # Example
433///
434/// ```no_run
435/// use neuron_types::*;
436///
437/// struct AllowAll;
438/// impl PermissionPolicy for AllowAll {
439///     fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
440///         PermissionDecision::Allow
441///     }
442/// }
443/// ```
444pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
445    /// Check whether a tool call is permitted.
446    fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
447}