neuron_types/traits.rs
//! Core traits: Provider, Tool, ToolDyn, ContextStrategy, ObservabilityHook, DurableContext, PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use crate::error::{ContextError, DurableError, HookError, ProviderError, ToolError};
10use crate::stream::StreamHandle;
11use crate::types::{
12 CompletionRequest, CompletionResponse, ContentItem, Message, ToolContext, ToolDefinition,
13 ToolOutput,
14};
15use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
16
17/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
18///
19/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
20/// Not object-safe by design; use generics `<P: Provider>` to compose.
21///
22/// # Example
23///
24/// ```ignore
25/// struct MyProvider;
26///
27/// impl Provider for MyProvider {
28/// fn complete(&self, request: CompletionRequest)
29/// -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
30/// {
31/// async { todo!() }
32/// }
33///
34/// fn complete_stream(&self, request: CompletionRequest)
35/// -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
36/// {
37/// async { todo!() }
38/// }
39/// }
40/// ```
41pub trait Provider: WasmCompatSend + WasmCompatSync {
42 /// Send a completion request and get a full response.
43 fn complete(
44 &self,
45 request: CompletionRequest,
46 ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
47
48 /// Send a completion request and get a stream of events.
49 fn complete_stream(
50 &self,
51 request: CompletionRequest,
52 ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
53}
54
55/// Strongly-typed tool trait. Implement this for your tools.
56///
57/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
58/// so you work with concrete Rust types.
59///
60/// # Example
61///
62/// ```ignore
63/// use neuron_types::*;
64/// use serde::Deserialize;
65///
66/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
67/// struct MyArgs { query: String }
68///
69/// struct MyTool;
70/// impl Tool for MyTool {
71/// const NAME: &'static str = "my_tool";
72/// type Args = MyArgs;
73/// type Output = String;
74/// type Error = std::io::Error;
75///
76/// fn definition(&self) -> ToolDefinition { todo!() }
77/// fn call(&self, args: MyArgs, ctx: &ToolContext)
78/// -> impl Future<Output = Result<String, std::io::Error>> + Send
79/// { async { Ok(args.query) } }
80/// }
81/// ```
82pub trait Tool: WasmCompatSend + WasmCompatSync {
83 /// The unique name of this tool.
84 const NAME: &'static str;
85 /// The deserialized input type.
86 type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
87 /// The serializable output type.
88 type Output: Serialize;
89 /// The tool-specific error type.
90 type Error: std::error::Error + WasmCompatSend + 'static;
91
92 /// Returns the tool definition (name, description, schema).
93 fn definition(&self) -> ToolDefinition;
94
95 /// Execute the tool with typed arguments.
96 fn call(
97 &self,
98 args: Self::Args,
99 ctx: &ToolContext,
100 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
101}
102
103/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
104///
105/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
106/// while preserving type safety at the implementation level.
107pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
108 /// The tool's unique name.
109 fn name(&self) -> &str;
110 /// The tool definition (name, description, input schema).
111 fn definition(&self) -> ToolDefinition;
112 /// Execute the tool with a JSON value input, returning a generic output.
113 fn call_dyn<'a>(
114 &'a self,
115 input: serde_json::Value,
116 ctx: &'a ToolContext,
117 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
118}
119
120/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
121///
122/// Handles:
123/// - Deserializing `serde_json::Value` into `T::Args`
124/// - Calling `T::call(args, ctx)`
125/// - Serializing `T::Output` into `ToolOutput`
126/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
127impl<T: Tool> ToolDyn for T {
128 fn name(&self) -> &str {
129 T::NAME
130 }
131
132 fn definition(&self) -> ToolDefinition {
133 Tool::definition(self)
134 }
135
136 fn call_dyn<'a>(
137 &'a self,
138 input: serde_json::Value,
139 ctx: &'a ToolContext,
140 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
141 Box::pin(async move {
142 let args: T::Args = serde_json::from_value(input)
143 .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
144
145 let output = self.call(args, ctx).await.map_err(|e| {
146 ToolError::ExecutionFailed(e.to_string().into())
147 })?;
148
149 let structured = serde_json::to_value(&output)
150 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
151
152 let text = match &structured {
153 serde_json::Value::String(s) => s.clone(),
154 other => other.to_string(),
155 };
156
157 Ok(ToolOutput {
158 content: vec![ContentItem::Text(text)],
159 structured_content: Some(structured),
160 is_error: false,
161 })
162 })
163 }
164}
165
166// --- Context Strategy ---
167
168/// Strategy for compacting conversation context when it exceeds token limits.
169///
170/// Implementations decide when to compact and how to reduce the message list.
171///
172/// # Example
173///
174/// ```ignore
175/// struct KeepLastN { n: usize }
176/// impl ContextStrategy for KeepLastN {
177/// fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
178/// token_count > 100_000
179/// }
180/// fn compact(&self, messages: Vec<Message>)
181/// -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
182/// {
183/// async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
184/// }
185/// fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
186/// }
187/// ```
188pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
189 /// Whether compaction should be triggered given the current messages and token count.
190 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
191
192 /// Compact the message list to reduce token usage.
193 fn compact(
194 &self,
195 messages: Vec<Message>,
196 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
197
198 /// Estimate the token count for a list of messages.
199 fn token_estimate(&self, messages: &[Message]) -> usize;
200}
201
202// --- Observability Hooks ---
203
204/// Events fired during the agentic loop for observability.
205#[derive(Debug)]
206pub enum HookEvent<'a> {
207 /// Start of a loop iteration.
208 LoopIteration {
209 /// The current turn number (0-indexed).
210 turn: usize,
211 },
212 /// Before calling the LLM provider.
213 PreLlmCall {
214 /// The request about to be sent.
215 request: &'a CompletionRequest,
216 },
217 /// After receiving the LLM response.
218 PostLlmCall {
219 /// The response received.
220 response: &'a CompletionResponse,
221 },
222 /// Before executing a tool.
223 PreToolExecution {
224 /// Name of the tool.
225 tool_name: &'a str,
226 /// Input arguments.
227 input: &'a serde_json::Value,
228 },
229 /// After executing a tool.
230 PostToolExecution {
231 /// Name of the tool.
232 tool_name: &'a str,
233 /// The tool's output.
234 output: &'a ToolOutput,
235 },
236 /// Context was compacted.
237 ContextCompaction {
238 /// Token count before compaction.
239 old_tokens: usize,
240 /// Token count after compaction.
241 new_tokens: usize,
242 },
243 /// A session started.
244 SessionStart {
245 /// The session identifier.
246 session_id: &'a str,
247 },
248 /// A session ended.
249 SessionEnd {
250 /// The session identifier.
251 session_id: &'a str,
252 },
253}
254
/// Action to take after processing a hook event.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy a decision and
/// compare actions directly (e.g. `action == HookAction::Continue`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
271
272/// Hook for observability (logging, metrics, telemetry).
273///
274/// Does NOT control execution flow beyond Continue/Skip/Terminate.
275/// For durable execution wrapping, use [`DurableContext`] instead.
276///
277/// # Example
278///
279/// ```ignore
280/// struct LogHook;
281/// impl ObservabilityHook for LogHook {
282/// fn on_event(&self, event: HookEvent<'_>)
283/// -> impl Future<Output = Result<HookAction, HookError>> + Send
284/// {
285/// async { println!("{event:?}"); Ok(HookAction::Continue) }
286/// }
287/// }
288/// ```
289pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
290 /// Called for each event in the agentic loop.
291 fn on_event(
292 &self,
293 event: HookEvent<'_>,
294 ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
295}
296
297// --- Durable Context ---
298
299/// Options for executing an activity in a durable context.
300#[derive(Debug, Clone)]
301pub struct ActivityOptions {
302 /// Maximum time for the activity to complete.
303 pub start_to_close_timeout: Duration,
304 /// Heartbeat interval for long-running activities.
305 pub heartbeat_timeout: Option<Duration>,
306 /// Retry policy for failed activities.
307 pub retry_policy: Option<RetryPolicy>,
308}
309
/// Exponential-backoff retry settings for durable activities.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Factor by which the delay grows between consecutive retries.
    pub backoff_coefficient: f64,
    /// Cap on the number of retry attempts.
    pub maximum_attempts: u32,
    /// Ceiling on the delay between two retries.
    pub maximum_interval: Duration,
    /// Names of error types that must never be retried.
    pub non_retryable_errors: Vec<String>,
}
324
325/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
326///
327/// When present, the agentic loop calls through this instead of directly
328/// calling provider/tools, enabling journaling, replay, and crash recovery.
329pub trait DurableContext: WasmCompatSend + WasmCompatSync {
330 /// Execute an LLM call as a durable activity.
331 fn execute_llm_call(
332 &self,
333 request: CompletionRequest,
334 options: ActivityOptions,
335 ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
336
337 /// Execute a tool call as a durable activity.
338 fn execute_tool(
339 &self,
340 tool_name: &str,
341 input: serde_json::Value,
342 ctx: &ToolContext,
343 options: ActivityOptions,
344 ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
345
346 /// Wait for an external signal with a timeout.
347 fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
348 &self,
349 signal_name: &str,
350 timeout: Duration,
351 ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
352
353 /// Whether the workflow should continue-as-new to avoid history bloat.
354 fn should_continue_as_new(&self) -> bool;
355
356 /// Continue the workflow as a new execution with the given state.
357 fn continue_as_new(
358 &self,
359 state: serde_json::Value,
360 ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
361
362 /// Sleep for a duration (durable — survives replay).
363 fn sleep(
364 &self,
365 duration: Duration,
366 ) -> impl Future<Output = ()> + WasmCompatSend;
367
368 /// Current time (deterministic during replay).
369 fn now(&self) -> chrono::DateTime<chrono::Utc>;
370}
371
372// --- Permission Policy ---
373
/// Decision from a permission check.
///
/// Derives `PartialEq`/`Eq` (besides `Clone`) so callers can compare
/// decisions directly, e.g. `decision == PermissionDecision::Allow`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call with a reason.
    Deny(String),
    /// Ask the user for confirmation.
    Ask(String),
}
384
385/// Policy for checking tool call permissions.
386///
387/// # Example
388///
389/// ```ignore
390/// struct AllowAll;
391/// impl PermissionPolicy for AllowAll {
392/// fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
393/// PermissionDecision::Allow
394/// }
395/// }
396/// ```
397pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
398 /// Check whether a tool call is permitted.
399 fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
400}