use std::sync::Arc;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use super::{SessionAppendNode, SessionCreateRequest};
use crate::runtime::PersistedSessionState;
use crate::{
ExecRequest, ExecResponse, ExecutionMode, LlmRequest, ModeExecutionContext, PromptUsage,
SessionReadView, ToolContract, ToolManifest, ToolResult,
};
/// Session-level hooks an execution mode can implement.
///
/// Every method ships with a default, so an implementor only overrides the
/// hooks it cares about. Unless stated otherwise, a default does nothing and
/// reports success.
#[async_trait::async_trait]
pub trait ModeSessionPlugin: Send + Sync {
    /// Hook for setting up mode-specific state on a session.
    ///
    /// Default: no-op, returns `Ok(())`.
    async fn initialize_session(
        &self,
        _ctx: ModeSessionContext<'_>,
    ) -> Result<(), crate::SessionError> {
        Ok(())
    }

    /// Hook for rebuilding mode-specific state from a persisted session.
    ///
    /// Default: ignores `_state`, returns `Ok(())`.
    async fn restore_session(
        &self,
        _ctx: ModeSessionContext<'_>,
        _state: &PersistedSessionState,
    ) -> Result<(), crate::SessionError> {
        Ok(())
    }

    /// Hook invoked with nodes being appended to a session.
    ///
    /// Default: ignores `_nodes`, returns `Ok(())`.
    async fn append_session_nodes(
        &self,
        _ctx: ModeSessionContext<'_>,
        _nodes: &[SessionAppendNode],
    ) -> Result<(), crate::SessionError> {
        Ok(())
    }

    /// Applies a mode-specific session extension.
    ///
    /// Default: rejects every extension with `SessionError::Protocol`.
    async fn apply_session_extension(
        &self,
        _extension: crate::ModeSessionExtensionHandle,
    ) -> Result<(), crate::SessionError> {
        let reason = "execution mode does not accept session extensions".to_owned();
        Err(crate::SessionError::Protocol(reason))
    }

    /// Validates a per-turn extension before it is used.
    ///
    /// Default: accepts everything, returns `Ok(())`.
    async fn validate_turn_extension(
        &self,
        _extension: &crate::ModeTurnExtensionHandle,
    ) -> Result<(), crate::SessionError> {
        Ok(())
    }

    /// Executes a code request on behalf of the session.
    ///
    /// Default: returns `SessionError::RlmUnavailable`.
    async fn execute_code(
        &self,
        _ctx: ModeExecutionContext,
        _request: ExecRequest,
    ) -> Result<ExecResponse, crate::SessionError> {
        Err(crate::SessionError::RlmUnavailable)
    }

    /// Reports whether the mode's execution state has unsaved changes.
    ///
    /// Default: `false`.
    fn execution_state_dirty(&self) -> bool {
        false
    }

    /// Serializes the mode's execution state, if it has any.
    ///
    /// Default: `Ok(None)`, meaning nothing to snapshot.
    async fn snapshot_execution_state(
        &self,
        _ctx: ModeSessionContext<'_>,
    ) -> Result<Option<Vec<u8>>, crate::SessionError> {
        Ok(None)
    }

    /// Restores execution state from bytes produced by a snapshot.
    ///
    /// Default: ignores `_data`, returns `Ok(())`.
    async fn restore_execution_state(
        &self,
        _ctx: ModeSessionContext<'_>,
        _data: &[u8],
    ) -> Result<(), crate::SessionError> {
        Ok(())
    }

    /// Lets the mode adjust runtime settings based on the session-create request.
    ///
    /// Default: leaves the runtime untouched.
    fn configure_runtime_from_request(
        &self,
        _ctx: ModeRuntimeContext<'_>,
        _request: &SessionCreateRequest,
    ) {}

    /// Hook run before an LLM call; may return an action to take instead.
    ///
    /// Default: `Ok(None)` (presumably "proceed normally" — confirm with the caller).
    async fn before_llm_call(
        &self,
        _ctx: ModeBeforeLlmCallContext,
        _request: &LlmRequest,
    ) -> Result<Option<ModeLlmCallAction>, crate::PluginError> {
        Ok(None)
    }
}
/// Context handed to [`ModeSessionPlugin`] session hooks.
///
/// At present it only exposes the session id; the session reference passed to
/// [`ModeSessionContext::new`] is not retained.
pub struct ModeSessionContext<'a> {
    session_id: &'a str,
}

impl<'a> ModeSessionContext<'a> {
    /// Builds a context for `session_id`.
    ///
    /// NOTE(review): `_session` is accepted but unused; presumably kept so the
    /// context can expose session access later — confirm before removing.
    pub(crate) fn new(_session: &'a mut crate::Session, session_id: &'a str) -> Self {
        Self { session_id }
    }

    /// Returns the id of the session this context refers to.
    ///
    /// The slice borrows for `'a`, not for `&self`. It therefore stays usable
    /// after the context (which hooks receive by value) is moved or dropped.
    pub fn session_id(&self) -> &'a str {
        self.session_id
    }
}
/// Owned context passed to [`ModeSessionPlugin::before_llm_call`].
pub struct ModeBeforeLlmCallContext {
    /// Id of the session about to make the LLM call.
    pub session_id: String,
    /// Shared handle to the runtime's session host.
    pub host: Arc<dyn crate::plugin::RuntimeSessionHost>,
    /// Read-only view of the session state.
    pub state: SessionReadView,
    /// Prompt usage reported for the most recent call, if any is known.
    pub latest_prompt_usage: Option<PromptUsage>,
}
/// Action a mode may request from [`ModeSessionPlugin::before_llm_call`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModeLlmCallAction {
    /// Hand off to the session identified by `session_id`.
    // NOTE(review): exact handoff semantics are defined by the caller — confirm.
    Handoff { session_id: String },
}
/// Mutable access to the runtime, handed to
/// [`ModeSessionPlugin::configure_runtime_from_request`].
pub struct ModeRuntimeContext<'rt> {
    runtime: &'rt mut crate::runtime::LashRuntime,
}

impl<'rt> ModeRuntimeContext<'rt> {
    /// Wraps a mutable borrow of the runtime.
    pub(crate) fn new(runtime: &'rt mut crate::runtime::LashRuntime) -> Self {
        ModeRuntimeContext { runtime }
    }

    /// Forwards `options` to the runtime's `set_mode_turn_options`.
    pub fn set_mode_turn_options(&mut self, options: crate::ModeTurnOptions) {
        let runtime = &mut *self.runtime;
        runtime.set_mode_turn_options(options);
    }
}
/// Tools an execution mode provides natively, with their manifests, contracts
/// and executor.
#[async_trait::async_trait]
pub trait ModeNativeToolsPlugin: Send + Sync {
    /// Returns the manifests of every tool this plugin exposes.
    fn tool_manifests(&self) -> Vec<ToolManifest>;

    /// Looks up a single manifest by tool name.
    ///
    /// The default scans [`Self::tool_manifests`] and returns the first
    /// manifest whose `name` equals `name`, or `None` if there is no match.
    fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
        for candidate in self.tool_manifests() {
            if candidate.name == name {
                return Some(candidate);
            }
        }
        None
    }

    /// Returns the shared contract for the tool called `name`, if any.
    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>>;

    /// Runs the tool called `name` with JSON `args`.
    ///
    /// `progress`, when provided, is a sender for progress updates.
    /// NOTE(review): `None` presumably means "not handled by this plugin" — confirm.
    async fn execute(
        &self,
        context: &crate::tool_dispatch::ToolDispatchContext,
        name: &str,
        args: &serde_json::Value,
        progress: Option<&crate::ProgressSender>,
    ) -> Option<ToolResult>;
}
/// Protocol driver for an execution mode.
pub trait ModeProtocolDriverPlugin: Send + Sync {
    /// Identifier of the mode this driver serves.
    fn mode_id(&self) -> &str;
    /// Builds the mode's preamble from `input`.
    fn build_preamble(&self, input: crate::ModeBuildInput) -> crate::ModePreamble;
}
/// Mode-specific extras tagged with the execution mode that owns them.
///
/// `payload` is arbitrary JSON whose schema is up to the mode named by
/// `mode_id`. Use [`ModeExtras::typed`] and [`ModeExtras::decode`] to convert
/// to and from a concrete type.
///
/// Only `Serialize` is derived. Deserialization is hand-written (see the
/// `Deserialize` impl) so it can also accept a legacy `{"mode": ..., ...}`
/// shape. As a result, field-level serde deserialize attributes would have no
/// effect here, and `payload` is required in the canonical
/// `{"mode_id": ..., "payload": ...}` shape.
#[derive(Clone, Debug, Serialize)]
pub struct ModeExtras {
    /// Execution mode that owns `payload`.
    pub mode_id: ExecutionMode,
    /// Mode-defined JSON payload.
    pub payload: serde_json::Value,
}
impl Default for ModeExtras {
fn default() -> Self {
Self::empty(ExecutionMode::standard())
}
}
impl ModeExtras {
    /// Creates extras for `mode_id` whose payload is an empty JSON object.
    pub fn empty(mode_id: ExecutionMode) -> Self {
        Self {
            mode_id,
            payload: serde_json::Value::Object(serde_json::Map::new()),
        }
    }

    /// Creates extras for `mode_id` by serializing `extras` into the payload.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `extras` cannot be converted to a
    /// JSON value.
    pub fn typed<T>(mode_id: ExecutionMode, extras: T) -> Result<Self, serde_json::Error>
    where
        T: Serialize,
    {
        Ok(Self {
            mode_id,
            payload: serde_json::to_value(extras)?,
        })
    }

    /// Decodes the payload as `T` when these extras belong to `expected_mode`.
    ///
    /// Returns `Ok(None)` without inspecting the payload when `mode_id` differs
    /// from `expected_mode`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the modes match but the payload does
    /// not deserialize as `T`.
    pub fn decode<T>(&self, expected_mode: &ExecutionMode) -> Result<Option<T>, serde_json::Error>
    where
        T: DeserializeOwned,
    {
        if self.mode_id != *expected_mode {
            return Ok(None);
        }
        // Deserialize directly from a borrow of the payload; `from_value` takes
        // the value by value and would force a deep clone of the JSON tree.
        T::deserialize(&self.payload).map(Some)
    }
}
impl<'de> Deserialize<'de> for ModeExtras {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
if let Some(object) = value.as_object() {
if let (Some(mode_id), Some(payload)) = (object.get("mode_id"), object.get("payload")) {
let mode_id = ExecutionMode::deserialize(mode_id.clone())
.map_err(serde::de::Error::custom)?;
return Ok(Self {
mode_id,
payload: payload.clone(),
});
}
if let Some(mode) = object.get("mode").and_then(serde_json::Value::as_str) {
let mut payload = object.clone();
payload.remove("mode");
return Ok(Self {
mode_id: ExecutionMode::new(mode),
payload: serde_json::Value::Object(payload),
});
}
}
Err(serde::de::Error::custom("invalid mode extras payload"))
}
}
/// Create-time extras for the standard execution mode; currently has no fields.
// Declared with braces rather than as a unit struct, so serde treats it as a
// struct with zero fields (`{}` in JSON) instead of a unit value (`null`).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct StandardCreateExtras {}