Skip to main content

lash/
mode.rs

1use crate::support::*;
2
3#[derive(
4    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
5)]
6pub struct ModeId(String);
7
8impl ModeId {
9    pub fn new(mode: impl Into<String>) -> Self {
10        Self(mode.into())
11    }
12
13    pub fn standard() -> Self {
14        Self("standard".to_string())
15    }
16
17    pub fn rlm() -> Self {
18        Self("rlm".to_string())
19    }
20
21    pub fn as_str(&self) -> &str {
22        &self.0
23    }
24}
25
26impl fmt::Display for ModeId {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.write_str(self.as_str())
29    }
30}
31
32/// Semantic mode preset installed on a [`LashCore`].
33#[derive(Clone)]
34pub struct ModePreset {
35    pub(crate) mode_id: ModeId,
36    pub(crate) factory: Arc<dyn PluginFactory>,
37}
38
39impl ModePreset {
40    pub fn standard() -> Self {
41        Self {
42            mode_id: ModeId::standard(),
43            factory: Arc::new(lash_protocol_standard::StandardProtocolPluginFactory::new()),
44        }
45    }
46
47    pub fn rlm() -> Self {
48        Self {
49            mode_id: ModeId::rlm(),
50            factory: Arc::new(lash_protocol_rlm::RlmProtocolPluginFactory::new(
51                rlm_preset_config(lash_protocol_rlm::RlmProtocolPluginConfig::default()),
52            )),
53        }
54    }
55
56    pub fn rlm_with_config(config: lash_protocol_rlm::RlmProtocolPluginConfig) -> Self {
57        Self {
58            mode_id: ModeId::rlm(),
59            factory: Arc::new(lash_protocol_rlm::RlmProtocolPluginFactory::new(
60                rlm_preset_config(config),
61            )),
62        }
63    }
64
65    pub fn rlm_with_projection_resolver(
66        projection_resolver: Arc<dyn lash_protocol_rlm::ProjectionResolver>,
67    ) -> Self {
68        Self::rlm_with_config_and_projection_resolver(
69            lash_protocol_rlm::RlmProtocolPluginConfig::default(),
70            projection_resolver,
71        )
72    }
73
74    pub fn rlm_with_config_and_projection_resolver(
75        config: lash_protocol_rlm::RlmProtocolPluginConfig,
76        projection_resolver: Arc<dyn lash_protocol_rlm::ProjectionResolver>,
77    ) -> Self {
78        Self {
79            mode_id: ModeId::rlm(),
80            factory: Arc::new(
81                lash_protocol_rlm::RlmProtocolPluginFactory::new(rlm_preset_config(config))
82                    .with_projection_resolver(projection_resolver),
83            ),
84        }
85    }
86
87    pub fn mode_id(&self) -> &ModeId {
88        &self.mode_id
89    }
90}
91
92pub trait RlmTurnBuilderExt: Sized {
93    fn require_submit(self) -> Result<Self>;
94    fn require_submit_schema(self, schema: serde_json::Value) -> Result<Self>;
95    fn allow_prose_or_submit(self) -> Result<Self>;
96}
97
98impl RlmTurnBuilderExt for TurnBuilder {
99    fn require_submit(self) -> Result<Self> {
100        rlm_termination(
101            self,
102            lash_rlm_types::RlmTermination::SubmitRequired { schema: None },
103        )
104    }
105
106    fn require_submit_schema(self, schema: serde_json::Value) -> Result<Self> {
107        rlm_termination(
108            self,
109            lash_rlm_types::RlmTermination::SubmitRequired {
110                schema: Some(schema),
111            },
112        )
113    }
114
115    fn allow_prose_or_submit(self) -> Result<Self> {
116        rlm_termination(self, lash_rlm_types::RlmTermination::ProseOrSubmit)
117    }
118}
119
120pub trait RlmSessionBuilderExt: Sized {
121    fn final_answer_format(self, format: lash_rlm_types::RlmFinalAnswerFormat) -> Self;
122}
123
124impl RlmSessionBuilderExt for SessionBuilder {
125    fn final_answer_format(mut self, format: lash_rlm_types::RlmFinalAnswerFormat) -> Self {
126        self.rlm_final_answer_format = Some(format);
127        self
128    }
129}
130
131fn rlm_termination(
132    mut builder: TurnBuilder,
133    termination: lash_rlm_types::RlmTermination,
134) -> Result<TurnBuilder> {
135    let override_options = ProtocolTurnOptions::typed(lash_rlm_types::RlmCreateExtras {
136        termination,
137        final_answer_format: None,
138    })?;
139    let options = builder
140        .protocol_turn_options
141        .as_ref()
142        .map(|current| current.merged_with_override(&override_options))
143        .unwrap_or(override_options);
144    builder.protocol_turn_options = Some(options);
145    Ok(builder)
146}
147
148fn rlm_preset_config(
149    config: lash_protocol_rlm::RlmProtocolPluginConfig,
150) -> lash_protocol_rlm::RlmProtocolPluginConfig {
151    let language_features = config.lashlang_language_features.with_label_annotations();
152    config.with_lashlang_language_features(language_features)
153}