Skip to main content

lash_core/session_model/
mod.rs

1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::ModelSpec;
9use crate::llm::types::{LlmEventSender, LlmStreamEvent};
10use crate::plugin::PluginMessage;
11use crate::provider::{ProviderBinding, ProviderHandle, ProviderResolutionError};
12
13pub use lash_sansio::format_tool_output_content;
14pub use lash_sansio::session_model::{
15    ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
16    PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
17    PruneState, SessionEvent, TokenUsage, TurnTerminationPolicyState, default_prompt_template,
18    make_error_envelope, make_error_event, reassign_part_ids, render_prompt,
19    render_transcript_prompt, shared_parts,
20};
21
22pub fn fresh_message_id() -> String {
23    format!("m{}", uuid::Uuid::new_v4().simple())
24}
25
26#[derive(Clone, Debug, PartialEq)]
27pub struct ProtocolEvent {
28    pub plugin_id: String,
29    pub payload: serde_json::Value,
30}
31
32impl ProtocolEvent {
33    pub fn typed<T>(plugin_id: impl Into<String>, event: T) -> Result<Self, serde_json::Error>
34    where
35        T: serde::Serialize,
36    {
37        Ok(Self {
38            plugin_id: plugin_id.into(),
39            payload: serde_json::to_value(event)?,
40        })
41    }
42
43    pub fn decode<T>(&self, expected_plugin_id: &str) -> Result<Option<T>, serde_json::Error>
44    where
45        T: for<'de> serde::Deserialize<'de>,
46    {
47        if self.plugin_id != expected_plugin_id {
48            return Ok(None);
49        }
50        serde_json::from_value(self.payload.clone()).map(Some)
51    }
52}
53
54impl serde::Serialize for ProtocolEvent {
55    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56    where
57        S: serde::Serializer,
58    {
59        #[derive(serde::Serialize)]
60        struct Tagged<'a> {
61            plugin_id: &'a str,
62            payload: &'a serde_json::Value,
63        }
64        Tagged {
65            plugin_id: &self.plugin_id,
66            payload: &self.payload,
67        }
68        .serialize(serializer)
69    }
70}
71
72impl<'de> serde::Deserialize<'de> for ProtocolEvent {
73    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74    where
75        D: serde::Deserializer<'de>,
76    {
77        let value = serde_json::Value::deserialize(deserializer)?;
78        if let Some(object) = value.as_object()
79            && let (Some(plugin_id), Some(payload)) =
80                (object.get("plugin_id"), object.get("payload"))
81        {
82            return Ok(Self {
83                plugin_id: plugin_id
84                    .as_str()
85                    .ok_or_else(|| serde::de::Error::custom("plugin_id must be a string"))?
86                    .to_string(),
87                payload: payload.clone(),
88            });
89        }
90        Err(serde::de::Error::custom(
91            "protocol events must be tagged with plugin_id and payload",
92        ))
93    }
94}
95
96pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ProtocolEvent>;
97
98/// Send an event to the channel if it's still open.
99pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
100    if !tx.is_closed() {
101        let _ = tx.send(event).await;
102    }
103}
104
105pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
106    let message_id = fresh_message_id();
107    let mut parts = if plugin_message.parts.is_empty() {
108        vec![Part {
109            id: format!("{message_id}.p0"),
110            kind: PartKind::Text,
111            content: plugin_message.content.clone(),
112            attachment: None,
113            tool_call_id: None,
114            tool_name: None,
115            tool_replay: None,
116            prune_state: PruneState::Intact,
117            reasoning_meta: None,
118            response_meta: None,
119        }]
120    } else {
121        plugin_message.parts.clone()
122    };
123    reassign_part_ids(&message_id, &mut parts);
124    Message {
125        id: message_id,
126        role: plugin_message.role,
127        parts: Arc::new(parts),
128        origin: plugin_message.origin.clone().or_else(|| {
129            Some(crate::MessageOrigin::Plugin {
130                plugin_id: "plugin".to_string(),
131                transient: false,
132            })
133        }),
134    }
135}
136
137#[derive(Clone, Debug, Default, PartialEq, Eq)]
138pub struct SessionPolicy {
139    pub model: ModelSpec,
140    pub provider_id: String,
141    pub session_id: Option<String>,
142    pub autonomous: bool,
143    pub max_turns: Option<usize>,
144    pub prompt: crate::PromptLayer,
145}
146
147impl SessionPolicy {
148    pub fn recorded_provider_id(&self) -> &str {
149        self.provider_id.trim()
150    }
151
152    pub fn model_id(&self) -> &str {
153        &self.model.id
154    }
155
156    pub fn model_variant(&self) -> Option<&str> {
157        self.model.variant.as_deref()
158    }
159
160    pub fn context_window_tokens(&self) -> usize {
161        self.model.context_window_tokens()
162    }
163}
164
165impl serde::Serialize for SessionPolicy {
166    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
167    where
168        S: serde::Serializer,
169    {
170        use serde::ser::SerializeStruct;
171
172        let mut fields = 5;
173        if !self.prompt.is_empty() {
174            fields += 1;
175        }
176        let mut state = serializer.serialize_struct("SessionPolicy", fields)?;
177        state.serialize_field("model", &self.model)?;
178        state.serialize_field("provider_id", self.recorded_provider_id())?;
179        state.serialize_field("session_id", &self.session_id)?;
180        state.serialize_field("autonomous", &self.autonomous)?;
181        state.serialize_field("max_turns", &self.max_turns)?;
182        if !self.prompt.is_empty() {
183            state.serialize_field("prompt", &self.prompt)?;
184        }
185        state.end()
186    }
187}
188
189impl<'de> serde::Deserialize<'de> for SessionPolicy {
190    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191    where
192        D: serde::Deserializer<'de>,
193    {
194        #[derive(serde::Deserialize)]
195        #[serde(deny_unknown_fields)]
196        struct Wire {
197            #[serde(default)]
198            model: ModelSpec,
199            #[serde(default)]
200            provider_id: String,
201            #[serde(default)]
202            session_id: Option<String>,
203            #[serde(default)]
204            autonomous: bool,
205            #[serde(default)]
206            max_turns: Option<usize>,
207            #[serde(default)]
208            prompt: crate::PromptLayer,
209        }
210
211        let value = serde_json::Value::deserialize(deserializer)?;
212        if value
213            .as_object()
214            .is_some_and(|object| object.contains_key("provider"))
215        {
216            return Err(serde::de::Error::custom(
217                "legacy serialized provider config is not supported in session state; persist provider_id only",
218            ));
219        }
220        let wire = Wire::deserialize(value).map_err(serde::de::Error::custom)?;
221        Ok(Self {
222            model: wire.model,
223            provider_id: wire.provider_id,
224            session_id: wire.session_id,
225            autonomous: wire.autonomous,
226            max_turns: wire.max_turns,
227            prompt: wire.prompt,
228        })
229    }
230}
231
232/// Runtime-only policy resolved against host-owned live dependencies.
233#[derive(Clone, Debug, Default, PartialEq, Eq)]
234pub struct RuntimeSessionPolicy {
235    pub policy: SessionPolicy,
236    pub binding: ProviderBinding,
237}
238
239impl RuntimeSessionPolicy {
240    pub fn new(policy: SessionPolicy, binding: ProviderBinding) -> Self {
241        Self { policy, binding }
242    }
243
244    pub fn from_provider(
245        policy: SessionPolicy,
246        provider: ProviderHandle,
247    ) -> Result<Self, ProviderResolutionError> {
248        let binding = ProviderBinding::new(policy.recorded_provider_id(), provider)?;
249        Ok(Self { policy, binding })
250    }
251
252    pub fn provider(&self) -> &ProviderHandle {
253        &self.binding.provider
254    }
255
256    pub fn into_policy(self) -> SessionPolicy {
257        self.policy
258    }
259}
260
261impl std::ops::Deref for RuntimeSessionPolicy {
262    type Target = SessionPolicy;
263
264    fn deref(&self) -> &Self::Target {
265        &self.policy
266    }
267}
268
269impl std::ops::DerefMut for RuntimeSessionPolicy {
270    fn deref_mut(&mut self) -> &mut Self::Target {
271        &mut self.policy
272    }
273}
274
275/// Reusable session configuration overlay.
276///
277/// `SessionSpec` is the public configuration shape for callers that want to
278/// describe either a root session or a child session without constructing the
279/// persisted [`SessionPolicy`] directly.
280#[derive(Clone, Debug, PartialEq, Eq)]
281pub struct SessionSpec {
282    inherit: bool,
283    pub provider_id: Option<String>,
284    pub model: Option<ModelSpec>,
285    pub max_turns: Option<Option<usize>>,
286    pub prompt: Option<crate::PromptLayer>,
287}
288
289impl SessionSpec {
290    /// Create an explicit root-style spec. Unset fields resolve from the
291    /// runtime's core defaults.
292    pub fn new() -> Self {
293        Self {
294            inherit: false,
295            provider_id: None,
296            model: None,
297            max_turns: None,
298            prompt: None,
299        }
300    }
301
302    /// Create a parent-relative spec. Unset fields inherit from the live
303    /// parent policy at resolution time.
304    pub fn inherit() -> Self {
305        Self {
306            inherit: true,
307            ..Self::new()
308        }
309    }
310
311    pub fn inherits(&self) -> bool {
312        self.inherit
313    }
314
315    pub fn provider_id(mut self, provider_id: impl Into<String>) -> Self {
316        self.provider_id = Some(provider_id.into());
317        self
318    }
319
320    pub fn model(mut self, model: ModelSpec) -> Self {
321        self.model = Some(model);
322        self
323    }
324
325    pub fn max_turns(mut self, max_turns: usize) -> Self {
326        self.max_turns = Some(Some(max_turns));
327        self
328    }
329
330    pub fn clear_max_turns(mut self) -> Self {
331        self.max_turns = Some(None);
332        self
333    }
334
335    pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
336        self.prompt = Some(prompt);
337        self
338    }
339
340    pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
341        let mut policy = base.clone();
342        if let Some(provider_id) = self.provider_id.as_ref() {
343            policy.provider_id = provider_id.clone();
344        }
345        if let Some(model) = self.model.as_ref() {
346            policy.model = model.clone();
347        }
348        if let Some(max_turns) = self.max_turns {
349            policy.max_turns = max_turns;
350        }
351        if let Some(prompt) = self.prompt.as_ref() {
352            policy.prompt = prompt.clone();
353        }
354        policy
355    }
356}
357
358impl Default for SessionSpec {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
364pub(crate) fn transport_stream_events(
365    provider: &ProviderHandle,
366    requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
367) -> Option<LlmEventSender> {
368    if let Some(requested) = requested {
369        return Some(make_stream_event_sender(requested));
370    }
371
372    if provider.requires_streaming() {
373        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
374        drop(rx);
375        Some(make_stream_event_sender(tx))
376    } else {
377        None
378    }
379}
380
381fn make_stream_event_sender(
382    tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
383) -> LlmEventSender {
384    LlmEventSender::new(move |event| {
385        let _ = tx.send(event);
386    })
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn protocol_event_writes_tagged_payload() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(event).expect("serialize");
        assert_eq!(serialized["plugin_id"], "test_protocol");
        assert!(serialized.get("payload").is_some());
    }

    #[test]
    fn protocol_event_round_trips_through_json() {
        let event = ProtocolEvent::typed("test_protocol", serde_json::json!({ "value": 42 }))
            .expect("typed event");
        let serialized = serde_json::to_value(&event).expect("serialize");
        let restored: ProtocolEvent = serde_json::from_value(serialized).expect("deserialize");
        assert_eq!(restored, event);
    }

    #[test]
    fn protocol_event_decode_filters_by_plugin_id() {
        let event = ProtocolEvent::typed("test_protocol", 7_u32).expect("typed event");
        assert_eq!(
            event.decode::<u32>("other_protocol").expect("decode other"),
            None
        );
        assert_eq!(
            event.decode::<u32>("test_protocol").expect("decode own"),
            Some(7)
        );
    }

    #[test]
    fn protocol_event_rejects_untagged_or_malformed_values() {
        let err = serde_json::from_value::<ProtocolEvent>(serde_json::json!({ "payload": 1 }))
            .expect_err("missing plugin_id must fail");
        assert!(err.to_string().contains("protocol events must be tagged"));

        let err = serde_json::from_value::<ProtocolEvent>(serde_json::json!({
            "plugin_id": 3,
            "payload": null
        }))
        .expect_err("non-string plugin_id must fail");
        assert!(err.to_string().contains("plugin_id must be a string"));
    }

    #[test]
    fn session_policy_rejects_legacy_provider_config() {
        let err = serde_json::from_value::<SessionPolicy>(serde_json::json!({
            "model": {},
            "provider": {
                "type": "openai",
                "api_key": "must-not-load"
            }
        }))
        .expect_err("legacy provider config must fail");

        assert!(
            err.to_string()
                .contains("legacy serialized provider config is not supported")
        );
    }

    #[test]
    fn session_policy_rejects_unknown_fields() {
        assert!(
            serde_json::from_value::<SessionPolicy>(serde_json::json!({ "bogus": true }))
                .is_err()
        );
    }

    #[test]
    fn session_policy_serializes_provider_id_without_provider_handle() {
        let policy = SessionPolicy {
            provider_id: "mock-provider".to_string(),
            model: ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid test model"),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
        assert!(value.get("provider").is_none());
    }

    #[test]
    fn session_policy_serialization_trims_provider_id() {
        let policy = SessionPolicy {
            provider_id: "  mock-provider \n".to_string(),
            ..SessionPolicy::default()
        };

        let value = serde_json::to_value(&policy).expect("serialize policy");

        assert_eq!(value["provider_id"], "mock-provider");
    }

    #[test]
    fn session_spec_overrides_only_set_fields() {
        let base = SessionPolicy {
            provider_id: "base-provider".to_string(),
            autonomous: true,
            max_turns: Some(10),
            ..SessionPolicy::default()
        };

        let resolved = SessionSpec::inherit().max_turns(3).resolve_against(&base);
        assert_eq!(resolved.max_turns, Some(3));
        assert_eq!(resolved.provider_id, "base-provider");
        assert!(resolved.autonomous);

        let cleared = SessionSpec::inherit()
            .clear_max_turns()
            .resolve_against(&base);
        assert_eq!(cleared.max_turns, None);

        let untouched = SessionSpec::new().resolve_against(&base);
        assert_eq!(untouched, base);
    }
}
429
430        assert_eq!(value["provider_id"], "mock-provider");
431        assert!(value.get("provider").is_none());
432    }
433}