1pub mod context;
2pub use lash_sansio::session_model::message;
3pub use lash_sansio::session_model::prompt;
4
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8use crate::llm::types::{LlmEventSender, LlmStreamEvent};
9use crate::plugin::PluginMessage;
10use crate::provider::ProviderHandle;
11use crate::{ExecutionMode, StandardContextApproach};
12
13pub use lash_sansio::session_model::{
14 ConversationRecord, ErrorEnvelope, MAIN_AGENT_INTRO, Message, MessageRole, Part, PartKind,
15 PromptBuiltin, PromptSlot, PromptTemplate, PromptTemplateEntry, PromptTemplateSection,
16 PruneState, SessionEvent, StateSnapshotEvent, TokenUsage, ToolEvent,
17 TurnTerminationPolicyState, default_prompt_template, format_tool_output_content,
18 format_tool_result_content, fresh_message_id, make_error_envelope, make_error_event,
19 reassign_part_ids, render_prompt, render_transcript_prompt, shared_parts,
20};
21
22#[derive(Clone, Debug, PartialEq)]
23pub struct ModeEvent {
24 pub mode_id: ExecutionMode,
25 pub payload: serde_json::Value,
26}
27
28impl ModeEvent {
29 pub fn typed<T>(mode_id: ExecutionMode, event: T) -> Result<Self, serde_json::Error>
30 where
31 T: serde::Serialize,
32 {
33 Ok(Self {
34 mode_id,
35 payload: serde_json::to_value(event)?,
36 })
37 }
38
39 pub fn decode<T>(&self, expected_mode: &ExecutionMode) -> Result<Option<T>, serde_json::Error>
40 where
41 T: for<'de> serde::Deserialize<'de>,
42 {
43 if &self.mode_id != expected_mode {
44 return Ok(None);
45 }
46 serde_json::from_value(self.payload.clone()).map(Some)
47 }
48}
49
50impl serde::Serialize for ModeEvent {
51 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
52 where
53 S: serde::Serializer,
54 {
55 #[derive(serde::Serialize)]
56 struct Tagged<'a> {
57 mode_id: &'a ExecutionMode,
58 payload: &'a serde_json::Value,
59 }
60 Tagged {
61 mode_id: &self.mode_id,
62 payload: &self.payload,
63 }
64 .serialize(serializer)
65 }
66}
67
68impl<'de> serde::Deserialize<'de> for ModeEvent {
69 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
70 where
71 D: serde::Deserializer<'de>,
72 {
73 let value = serde_json::Value::deserialize(deserializer)?;
74 if let Some(object) = value.as_object()
75 && let (Some(mode_id), Some(payload)) = (object.get("mode_id"), object.get("payload"))
76 {
77 let mode_id =
78 ExecutionMode::deserialize(mode_id.clone()).map_err(serde::de::Error::custom)?;
79 return Ok(Self {
80 mode_id,
81 payload: payload.clone(),
82 });
83 }
84 Err(serde::de::Error::custom(
85 "mode events must be tagged with mode_id and payload",
86 ))
87 }
88}
89
/// The `lash_sansio` session event record, instantiated with this crate's [`ModeEvent`]
/// as its mode-event type.
pub type SessionEventRecord = lash_sansio::session_model::SessionEventRecord<ModeEvent>;
91
92pub(crate) async fn send_event(tx: &mpsc::Sender<SessionEvent>, event: SessionEvent) {
94 if !tx.is_closed() {
95 let _ = tx.send(event).await;
96 }
97}
98
99pub(crate) fn plugin_message_to_message(plugin_message: &PluginMessage) -> Message {
100 let message_id = fresh_message_id();
101 let mut parts = if plugin_message.parts.is_empty() {
102 vec![Part {
103 id: format!("{message_id}.p0"),
104 kind: PartKind::Text,
105 content: plugin_message.content.clone(),
106 attachment: None,
107 tool_call_id: None,
108 tool_name: None,
109 tool_replay: None,
110 prune_state: PruneState::Intact,
111 reasoning_meta: None,
112 response_meta: None,
113 }]
114 } else {
115 plugin_message.parts.clone()
116 };
117 reassign_part_ids(&message_id, &mut parts);
118 Message {
119 id: message_id,
120 role: plugin_message.role,
121 parts: Arc::new(parts),
122 origin: Some(crate::MessageOrigin::Plugin {
123 plugin_id: "plugin".to_string(),
124 transient: false,
125 }),
126 }
127}
128
129#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
136pub struct SessionPolicy {
137 pub model: String,
138 pub provider: ProviderHandle,
139 pub max_context_tokens: Option<usize>,
140 pub model_variant: Option<String>,
141 pub session_id: Option<String>,
142 #[serde(default)]
143 pub autonomous: bool,
144 pub max_turns: Option<usize>,
145 pub execution_mode: ExecutionMode,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub standard_context_approach: Option<StandardContextApproach>,
148 #[serde(default, skip_serializing_if = "crate::PromptLayer::is_empty")]
149 pub prompt: crate::PromptLayer,
150}
151
152impl SessionPolicy {
153 pub fn normalize_for_execution_mode(&mut self) {
155 if self.execution_mode != ExecutionMode::standard() {
156 self.standard_context_approach = None;
157 }
158 }
159
160 pub fn normalized_for_execution_mode(mut self) -> Self {
161 self.normalize_for_execution_mode();
162 self
163 }
164}
165
166#[derive(Clone, Debug, PartialEq, Eq)]
172pub struct SessionSpec {
173 inherit: bool,
174 pub provider: Option<ProviderHandle>,
175 pub model: Option<String>,
176 pub model_variant: Option<Option<String>>,
177 pub execution_mode: Option<ExecutionMode>,
178 pub max_context_tokens: Option<usize>,
179 pub max_turns: Option<Option<usize>>,
180 pub prompt: Option<crate::PromptLayer>,
181}
182
183impl SessionSpec {
184 pub fn new() -> Self {
187 Self {
188 inherit: false,
189 provider: None,
190 model: None,
191 model_variant: None,
192 execution_mode: None,
193 max_context_tokens: None,
194 max_turns: None,
195 prompt: None,
196 }
197 }
198
199 pub fn inherit() -> Self {
202 Self {
203 inherit: true,
204 ..Self::new()
205 }
206 }
207
208 pub fn inherits(&self) -> bool {
209 self.inherit
210 }
211
212 pub fn provider(mut self, provider: ProviderHandle) -> Self {
213 self.provider = Some(provider);
214 self
215 }
216
217 pub fn model(mut self, model: impl Into<String>, variant: Option<String>) -> Self {
218 self.model = Some(model.into());
219 self.model_variant = Some(variant);
220 self
221 }
222
223 pub fn model_variant(mut self, variant: impl Into<String>) -> Self {
224 self.model_variant = Some(Some(variant.into()));
225 self
226 }
227
228 pub fn clear_model_variant(mut self) -> Self {
229 self.model_variant = Some(None);
230 self
231 }
232
233 pub fn mode(mut self, mode: ExecutionMode) -> Self {
234 self.execution_mode = Some(mode);
235 self
236 }
237
238 pub fn max_context_tokens(mut self, max_context_tokens: usize) -> Self {
239 self.max_context_tokens = Some(max_context_tokens);
240 self
241 }
242
243 pub fn max_turns(mut self, max_turns: usize) -> Self {
244 self.max_turns = Some(Some(max_turns));
245 self
246 }
247
248 pub fn clear_max_turns(mut self) -> Self {
249 self.max_turns = Some(None);
250 self
251 }
252
253 pub fn prompt_layer(mut self, prompt: crate::PromptLayer) -> Self {
254 self.prompt = Some(prompt);
255 self
256 }
257
258 pub fn resolve_against(&self, base: &SessionPolicy) -> SessionPolicy {
259 let mut policy = base.clone();
260 if let Some(provider) = self.provider.as_ref() {
261 policy.provider = provider.clone();
262 if self.model.is_none() {
263 let model = provider.default_model().to_string();
264 policy.model_variant = provider.default_model_variant(&model).map(str::to_string);
265 policy.model = model;
266 }
267 }
268 if let Some(model) = self.model.as_ref() {
269 policy.model = model.clone();
270 }
271 if let Some(model_variant) = self.model_variant.as_ref() {
272 policy.model_variant = model_variant.clone();
273 }
274 if let Some(max_context_tokens) = self.max_context_tokens {
275 policy.max_context_tokens = Some(max_context_tokens);
276 }
277 if let Some(max_turns) = self.max_turns {
278 policy.max_turns = max_turns;
279 }
280 if let Some(execution_mode) = self.execution_mode.as_ref() {
281 policy.execution_mode = execution_mode.clone();
282 }
283 if let Some(prompt) = self.prompt.as_ref() {
284 policy.prompt = prompt.clone();
285 }
286 policy
287 }
288}
289
290impl Default for SessionSpec {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
296impl Default for SessionPolicy {
297 fn default() -> Self {
298 Self {
299 model: String::new(),
300 provider: ProviderHandle::default(),
301 max_context_tokens: None,
302 model_variant: None,
303 session_id: None,
304 autonomous: false,
305 max_turns: None,
306 execution_mode: ExecutionMode::standard(),
307 standard_context_approach: Some(StandardContextApproach::default()),
308 prompt: crate::PromptLayer::default(),
309 }
310 }
311}
312
313pub(crate) fn transport_stream_events(
314 provider: &ProviderHandle,
315 requested: Option<tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>>,
316) -> Option<LlmEventSender> {
317 if let Some(requested) = requested {
318 return Some(make_stream_event_sender(requested));
319 }
320
321 if provider.requires_streaming() {
322 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<LlmStreamEvent>();
323 drop(rx);
324 Some(make_stream_event_sender(tx))
325 } else {
326 None
327 }
328}
329
330fn make_stream_event_sender(
331 tx: tokio::sync::mpsc::UnboundedSender<LlmStreamEvent>,
332) -> LlmEventSender {
333 LlmEventSender::new(move |event| {
334 let _ = tx.send(event);
335 })
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ModeEvent` tagged with mode `"test"` carrying `{ "value": 42 }`.
    fn sample_event() -> ModeEvent {
        ModeEvent::typed(
            ExecutionMode::new("test"),
            serde_json::json!({ "value": 42 }),
        )
        .expect("typed event")
    }

    #[test]
    fn mode_event_writes_tagged_payload() {
        let serialized = serde_json::to_value(sample_event()).expect("serialize");
        assert_eq!(serialized["mode_id"], "test");
        assert_eq!(serialized["payload"], serde_json::json!({ "value": 42 }));
    }

    #[test]
    fn mode_event_round_trips_through_json() {
        let event = sample_event();
        let serialized = serde_json::to_value(&event).expect("serialize");
        let restored: ModeEvent = serde_json::from_value(serialized).expect("deserialize");
        assert_eq!(restored, event);
    }

    #[test]
    fn mode_event_rejects_untagged_json() {
        let untagged = serde_json::json!({ "value": 42 });
        assert!(serde_json::from_value::<ModeEvent>(untagged).is_err());

        let missing_payload = serde_json::json!({ "mode_id": "test" });
        assert!(serde_json::from_value::<ModeEvent>(missing_payload).is_err());

        let not_an_object = serde_json::json!(["test", { "value": 42 }]);
        assert!(serde_json::from_value::<ModeEvent>(not_an_object).is_err());
    }

    #[test]
    fn mode_event_decodes_only_for_matching_mode() {
        let event = sample_event();

        let matching: Option<serde_json::Value> = event
            .decode(&ExecutionMode::new("test"))
            .expect("decode matching mode");
        assert_eq!(matching, Some(serde_json::json!({ "value": 42 })));

        let other: Option<serde_json::Value> = event
            .decode(&ExecutionMode::new("other"))
            .expect("decode other mode");
        assert_eq!(other, None);
    }

    #[test]
    fn normalization_clears_standard_only_settings() {
        let standard = SessionPolicy::default().normalized_for_execution_mode();
        assert!(standard.standard_context_approach.is_some());

        let non_standard = SessionPolicy {
            execution_mode: ExecutionMode::new("test"),
            ..SessionPolicy::default()
        }
        .normalized_for_execution_mode();
        assert_eq!(non_standard.standard_context_approach, None);
    }

    #[test]
    fn session_spec_overrides_and_clears() {
        let base = SessionPolicy {
            max_turns: Some(7),
            ..SessionPolicy::default()
        };

        assert_eq!(SessionSpec::new().resolve_against(&base), base);

        let limited = SessionSpec::new()
            .max_turns(3)
            .max_context_tokens(1024)
            .resolve_against(&base);
        assert_eq!(limited.max_turns, Some(3));
        assert_eq!(limited.max_context_tokens, Some(1024));

        let cleared = SessionSpec::new().clear_max_turns().resolve_against(&base);
        assert_eq!(cleared.max_turns, None);

        let renamed = SessionSpec::new()
            .model("m", Some("v".to_string()))
            .resolve_against(&base);
        assert_eq!(renamed.model, "m");
        assert_eq!(renamed.model_variant.as_deref(), Some("v"));
    }
}