Skip to main content

lash_core/session_model/
mod.rs

1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::llm::types::{LlmEventSender, LlmStreamEvent};
9use crate::plugin::PluginMessage;
10use crate::provider::ProviderHandle;
11use crate::{ExecutionMode, StandardContextApproach};
12
13pub use lash_sansio::session_model::{
14    ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
15    PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
16    PruneState, SessionEvent, StateSnapshotEvent, TokenUsage, ToolEvent,
17    TurnTerminationPolicyState, default_prompt_template, format_tool_output_content,
18    format_tool_result_content, fresh_message_id, make_error_envelope, make_error_event,
19    reassign_part_ids, render_prompt, render_transcript_prompt, shared_parts,
20};
21
/// An execution-mode-specific event whose body is kept as untyped JSON.
///
/// Build one from any `Serialize` value with [`ModeEvent::typed`] and read it
/// back with [`ModeEvent::decode`]. When serialized it is written as an object
/// with a `mode_id` key and a `payload` key.
#[derive(Clone, Debug, PartialEq)]
pub struct ModeEvent {
    /// Execution mode this event belongs to; [`ModeEvent::decode`] only yields
    /// a value when the caller's expected mode equals this one.
    pub mode_id: ExecutionMode,
    /// The event body, already converted to JSON.
    pub payload: serde_json::Value,
}
27
28impl ModeEvent {
29    pub fn typed<T>(mode_id: ExecutionMode, event: T) -> Result<Self, serde_json::Error>
30    where
31        T: serde::Serialize,
32    {
33        Ok(Self {
34            mode_id,
35            payload: serde_json::to_value(event)?,
36        })
37    }
38
39    pub fn decode<T>(&self, expected_mode: &ExecutionMode) -> Result<Option<T>, serde_json::Error>
40    where
41        T: for<'de> serde::Deserialize<'de>,
42    {
43        if &self.mode_id != expected_mode {
44            return Ok(None);
45        }
46        serde_json::from_value(self.payload.clone()).map(Some)
47    }
48}
49
50impl serde::Serialize for ModeEvent {
51    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
52    where
53        S: serde::Serializer,
54    {
55        #[derive(serde::Serialize)]
56        struct Tagged<'a> {
57            mode_id: &'a ExecutionMode,
58            payload: &'a serde_json::Value,
59        }
60        Tagged {
61            mode_id: &self.mode_id,
62            payload: &self.payload,
63        }
64        .serialize(serializer)
65    }
66}
67
68impl<'de> serde::Deserialize<'de> for ModeEvent {
69    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
70    where
71        D: serde::Deserializer<'de>,
72    {
73        let value = serde_json::Value::deserialize(deserializer)?;
74        if let Some(object) = value.as_object()
75            && let (Some(mode_id), Some(payload)) = (object.get("mode_id"), object.get("payload"))
76        {
77            let mode_id =
78                ExecutionMode::deserialize(mode_id.clone()).map_err(serde::de::Error::custom)?;
79            return Ok(Self {
80                mode_id,
81                payload: payload.clone(),
82            });
83        }
84        Err(serde::de::Error::custom(
85            "mode events must be tagged with mode_id and payload",
86        ))
87    }
88}
89
/// The `lash_sansio` session event record, instantiated with [`ModeEvent`] as
/// its generic parameter (presumably the carrier for mode-specific events —
/// confirm against `lash_sansio::session_model::SessionEventRecord`).
pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ModeEvent>;
91
92/// Send an event to the channel if it's still open.
93pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
94    if !tx.is_closed() {
95        let _ = tx.send(event).await;
96    }
97}
98
99pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
100    let message_id = fresh_message_id();
101    let mut parts = if plugin_message.parts.is_empty() {
102        vec![Part {
103            id: format!("{message_id}.p0"),
104            kind: PartKind::Text,
105            content: plugin_message.content.clone(),
106            attachment: None,
107            tool_call_id: None,
108            tool_name: None,
109            tool_replay: None,
110            prune_state: PruneState::Intact,
111            reasoning_meta: None,
112            response_meta: None,
113        }]
114    } else {
115        plugin_message.parts.clone()
116    };
117    reassign_part_ids(&message_id, &mut parts);
118    Message {
119        id: message_id,
120        role: plugin_message.role,
121        parts: Arc::new(parts),
122        origin: Some(crate::MessageOrigin::Plugin {
123            plugin_id: "plugin".to_string(),
124            transient: false,
125        }),
126    }
127}
128
/// Resolved session policy for a running session.
///
/// `provider` is a [`ProviderHandle`] — serializes through
/// [`crate::provider::ProviderSpec`], rebuilt via the global
/// [`crate::provider::ProviderRegistry`] on load. Hosts register the
/// concrete provider types they support at startup.
///
/// After changing `execution_mode`, call
/// [`SessionPolicy::normalize_for_execution_mode`] to drop settings that only
/// apply to the standard mode.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SessionPolicy {
    /// Model name used with `provider` (empty in [`SessionPolicy::default`]).
    pub model: String,
    /// Provider backend for this session.
    pub provider: ProviderHandle,
    /// Optional context-window limit in tokens; `None` leaves it unset.
    pub max_context_tokens: Option<usize>,
    /// Optional variant of `model`.
    pub model_variant: Option<String>,
    /// Session identifier, if one has been assigned.
    pub session_id: Option<String>,
    /// Read as `false` when absent from serialized input.
    #[serde(default)]
    pub autonomous: bool,
    /// Optional turn limit. NOTE(review): how the runtime treats `None`
    /// (e.g. "unlimited") is not visible here — confirm.
    pub max_turns: Option<usize>,
    /// Execution mode the session runs in.
    pub execution_mode: ExecutionMode,
    /// Context approach for the standard execution mode. Cleared by
    /// `normalize_for_execution_mode` for any other mode. Omitted from
    /// serialized output when `None` and read back as `None` when absent.
    ///
    /// NOTE(review): `Default` sets this to `Some(StandardContextApproach::default())`,
    /// but a deserialized policy missing the key gets `None` — confirm that
    /// asymmetry is intended.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub standard_context_approach: Option<StandardContextApproach>,
    /// Prompt layer for this session. Omitted from serialized output when
    /// empty and defaulted when absent.
    #[serde(default, skip_serializing_if = "crate::PromptLayer::is_empty")]
    pub prompt: crate::PromptLayer,
}
151
152impl SessionPolicy {
153    /// Drop policy fields that only apply to the standard execution mode.
154    pub fn normalize_for_execution_mode(&mut self) {
155        if self.execution_mode != ExecutionMode::standard() {
156            self.standard_context_approach = None;
157        }
158    }
159
160    pub fn normalized_for_execution_mode(mut self) -> Self {
161        self.normalize_for_execution_mode();
162        self
163    }
164}
165
/// Reusable session configuration overlay.
///
/// `SessionSpec` is the public configuration shape for callers that want to
/// describe either a root session or a child session without constructing the
/// resolved runtime-only [`SessionPolicy`] directly.
///
/// Every field is an override applied by [`SessionSpec::resolve_against`]:
/// `None` keeps whatever the base policy already has. Fields typed
/// `Option<Option<_>>` distinguish "leave unchanged" (`None`) from "explicitly
/// clear" (`Some(None)`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SessionSpec {
    /// Set only by [`SessionSpec::inherit`] and reported by
    /// [`SessionSpec::inherits`]. `resolve_against` does not read it; the
    /// caller decides which base policy to resolve against.
    inherit: bool,
    /// Replacement provider. When set without `model`, resolution also
    /// switches to the provider's default model and model variant.
    pub provider: Option<ProviderHandle>,
    /// Replacement model name.
    pub model: Option<String>,
    /// `Some(v)` replaces the base model variant with `v` (which may be `None`).
    pub model_variant: Option<Option<String>>,
    /// Replacement execution mode.
    pub execution_mode: Option<ExecutionMode>,
    /// Replacement context-token limit. This field cannot clear the base
    /// value, only replace it.
    pub max_context_tokens: Option<usize>,
    /// `Some(v)` replaces the base turn limit with `v` (which may be `None`).
    pub max_turns: Option<Option<usize>>,
    /// Replacement prompt layer; it replaces the base layer wholesale.
    pub prompt: Option<crate::PromptLayer>,
}
182
183impl SessionSpec {
184    /// Create an explicit root-style spec. Unset fields resolve from the
185    /// runtime's core defaults.
186    pub fn new() -> Self {
187        Self {
188            inherit: false,
189            provider: None,
190            model: None,
191            model_variant: None,
192            execution_mode: None,
193            max_context_tokens: None,
194            max_turns: None,
195            prompt: None,
196        }
197    }
198
199    /// Create a parent-relative spec. Unset fields inherit from the live
200    /// parent policy at resolution time.
201    pub fn inherit() -> Self {
202        Self {
203            inherit: true,
204            ..Self::new()
205        }
206    }
207
208    pub fn inherits(&self) -> bool {
209        self.inherit
210    }
211
212    pub fn provider(mut self, provider: ProviderHandle) -> Self {
213        self.provider = Some(provider);
214        self
215    }
216
217    pub fn model(mut self, model: impl Into<String>, variant: Option<String>) -> Self {
218        self.model = Some(model.into());
219        self.model_variant = Some(variant);
220        self
221    }
222
223    pub fn model_variant(mut self, variant: impl Into<String>) -> Self {
224        self.model_variant = Some(Some(variant.into()));
225        self
226    }
227
228    pub fn clear_model_variant(mut self) -> Self {
229        self.model_variant = Some(None);
230        self
231    }
232
233    pub fn mode(mut self, mode: ExecutionMode) -> Self {
234        self.execution_mode = Some(mode);
235        self
236    }
237
238    pub fn max_context_tokens(mut self, max_context_tokens: usize) -> Self {
239        self.max_context_tokens = Some(max_context_tokens);
240        self
241    }
242
243    pub fn max_turns(mut self, max_turns: usize) -> Self {
244        self.max_turns = Some(Some(max_turns));
245        self
246    }
247
248    pub fn clear_max_turns(mut self) -> Self {
249        self.max_turns = Some(None);
250        self
251    }
252
253    pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
254        self.prompt = Some(prompt);
255        self
256    }
257
258    pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
259        let mut policy = base.clone();
260        if let Some(provider) = self.provider.as_ref() {
261            policy.provider = provider.clone();
262            if self.model.is_none() {
263                let model = provider.default_model().to_string();
264                policy.model_variant = provider.default_model_variant(&model).map(str::to_string);
265                policy.model = model;
266            }
267        }
268        if let Some(model) = self.model.as_ref() {
269            policy.model = model.clone();
270        }
271        if let Some(model_variant) = self.model_variant.as_ref() {
272            policy.model_variant = model_variant.clone();
273        }
274        if let Some(max_context_tokens) = self.max_context_tokens {
275            policy.max_context_tokens = Some(max_context_tokens);
276        }
277        if let Some(max_turns) = self.max_turns {
278            policy.max_turns = max_turns;
279        }
280        if let Some(execution_mode) = self.execution_mode.as_ref() {
281            policy.execution_mode = execution_mode.clone();
282        }
283        if let Some(prompt) = self.prompt.as_ref() {
284            policy.prompt = prompt.clone();
285        }
286        policy
287    }
288}
289
290impl Default for SessionSpec {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
296impl Default for SessionPolicy {
297    fn default() -> Self {
298        Self {
299            model: String::new(),
300            provider: ProviderHandle::default(),
301            max_context_tokens: None,
302            model_variant: None,
303            session_id: None,
304            autonomous: false,
305            max_turns: None,
306            execution_mode: ExecutionMode::standard(),
307            standard_context_approach: Some(StandardContextApproach::default()),
308            prompt: crate::PromptLayer::default(),
309        }
310    }
311}
312
313pub(crate) fn transport_stream_events(
314    provider: &ProviderHandle,
315    requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
316) -> Option<LlmEventSender> {
317    if let Some(requested) = requested {
318        return Some(make_stream_event_sender(requested));
319    }
320
321    if provider.requires_streaming() {
322        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
323        drop(rx);
324        Some(make_stream_event_sender(tx))
325    } else {
326        None
327    }
328}
329
330fn make_stream_event_sender(
331    tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
332) -> LlmEventSender {
333    LlmEventSender::new(move |event| {
334        let _ = tx.send(event);
335    })
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ModeEvent` tagged with mode `"test"` and payload `{"value": 42}`.
    fn sample_event() -> ModeEvent {
        ModeEvent::typed(
            ExecutionMode::new("test"),
            serde_json::json!({ "value": 42 }),
        )
        .expect("typed event")
    }

    #[test]
    fn mode_event_writes_tagged_payload() {
        let serialized = serde_json::to_value(sample_event()).expect("serialize");
        assert_eq!(serialized["mode_id"], "test");
        assert_eq!(serialized["payload"], serde_json::json!({ "value": 42 }));
    }

    #[test]
    fn mode_event_round_trips_through_json() {
        let event = sample_event();
        let serialized = serde_json::to_value(&event).expect("serialize");
        let restored: ModeEvent = serde_json::from_value(serialized).expect("deserialize");
        assert_eq!(restored, event);
    }

    #[test]
    fn mode_event_rejects_untagged_json() {
        let untagged = serde_json::json!({ "value": 42 });
        assert!(serde_json::from_value::<ModeEvent>(untagged).is_err());

        let missing_payload = serde_json::json!({ "mode_id": "test" });
        assert!(serde_json::from_value::<ModeEvent>(missing_payload).is_err());

        let not_an_object = serde_json::json!([1, 2, 3]);
        assert!(serde_json::from_value::<ModeEvent>(not_an_object).is_err());
    }

    #[test]
    fn mode_event_decode_only_matches_own_mode() {
        let event = sample_event();

        let decoded: Option<serde_json::Value> =
            event.decode(&ExecutionMode::new("test")).expect("decode");
        assert_eq!(decoded, Some(serde_json::json!({ "value": 42 })));

        let other: Option<serde_json::Value> =
            event.decode(&ExecutionMode::new("other")).expect("decode");
        assert_eq!(other, None);
    }

    #[test]
    fn normalize_drops_standard_only_fields_for_other_modes() {
        let standard = SessionPolicy::default().normalized_for_execution_mode();
        assert!(standard.standard_context_approach.is_some());

        let other = SessionPolicy {
            execution_mode: ExecutionMode::new("test"),
            ..SessionPolicy::default()
        }
        .normalized_for_execution_mode();
        assert!(other.standard_context_approach.is_none());
    }

    #[test]
    fn session_spec_overrides_only_set_fields() {
        let base = SessionPolicy {
            model: "base-model".to_string(),
            model_variant: Some("fast".to_string()),
            max_turns: Some(3),
            ..SessionPolicy::default()
        };

        let untouched = SessionSpec::inherit().resolve_against(&base);
        assert_eq!(untouched, base);

        let resolved = SessionSpec::new()
            .model("other", None)
            .max_turns(7)
            .resolve_against(&base);
        assert_eq!(resolved.model, "other");
        assert_eq!(resolved.model_variant, None);
        assert_eq!(resolved.max_turns, Some(7));

        let cleared = SessionSpec::inherit()
            .clear_max_turns()
            .resolve_against(&base);
        assert_eq!(cleared.max_turns, None);
        assert_eq!(cleared.model, "base-model");
        assert_eq!(cleared.model_variant, Some("fast".to_string()));
    }
}