Skip to main content

telltale_language/extensions/
timeout.rs

//! Example timeout extension for choreographic protocols.
//!
//! This demonstrates how to create a complete extension that adds timeout
//! functionality to choreographic protocols.
5
6use super::{
7    CodegenContext, ExtensionRegistry, ExtensionValidationError, GrammarExtension, ParseContext,
8    ParseError, ProjectionContext, ProtocolExtension, StatementParser,
9};
10use crate::ast::{LocalType, Role};
11use crate::compiler::projection::ProjectionError;
12use std::any::{Any, TypeId};
13use std::time::Duration;
14
15/// Grammar extension that adds timeout syntax
16#[derive(Debug)]
17pub struct TimeoutGrammarExtension;
18
19impl GrammarExtension for TimeoutGrammarExtension {
20    fn grammar_rules(&self) -> &'static str {
21        r#"
22timeout_ext_stmt = { "timeout" ~ timeout_ext_duration ~ timeout_ext_roles ~ "{" ~ protocol_body ~ "}" }
23timeout_ext_duration = { integer ~ timeout_ext_time_unit? }
24timeout_ext_time_unit = { "ms" | "s" | "m" | "h" }
25timeout_ext_roles = { "(" ~ role_list ~ ")" | role_ref }
26"#
27    }
28
29    fn statement_rules(&self) -> Vec<&'static str> {
30        vec!["timeout_ext_stmt"]
31    }
32
33    fn priority(&self) -> u32 {
34        200 // Higher than default to take precedence
35    }
36
37    fn extension_id(&self) -> &'static str {
38        "timeout"
39    }
40}
41
42/// Statement parser for timeout constructs
43#[derive(Debug)]
44pub struct TimeoutStatementParser;
45
46impl StatementParser for TimeoutStatementParser {
47    fn can_parse(&self, rule_name: &str) -> bool {
48        rule_name == "timeout_ext_stmt"
49    }
50
51    fn supported_rules(&self) -> Vec<String> {
52        vec!["timeout_ext_stmt".to_string()]
53    }
54
55    fn parse_statement(
56        &self,
57        rule_name: &str,
58        _content: &str,
59        context: &ParseContext,
60    ) -> Result<Box<dyn ProtocolExtension>, ParseError> {
61        if rule_name != "timeout_ext_stmt" {
62            return Err(ParseError::InvalidSyntax {
63                details: format!("Expected timeout_ext_stmt, got {}", rule_name),
64            });
65        }
66
67        // Parse the timeout statement
68        // This is a simplified parser - in practice, you'd use the Pest parse tree
69        let timeout_protocol = self.parse_timeout_content(_content, context)?;
70        Ok(Box::new(timeout_protocol))
71    }
72}
73
74impl TimeoutStatementParser {
75    fn parse_timeout_content(
76        &self,
77        content: &str,
78        context: &ParseContext,
79    ) -> Result<TimeoutProtocol, ParseError> {
80        // Simplified parsing - extract duration and roles (production would use Pest tree)
81        let duration_ms = self.extract_duration(content)?;
82        let roles = self.extract_roles(content, context)?;
83
84        // Body defaults to End; full implementation would recursively parse
85        Ok(TimeoutProtocol {
86            duration: Duration::from_millis(duration_ms),
87            role_names: roles.iter().map(|r| r.name().to_string()).collect(),
88            body_repr: "End".to_string(),
89        })
90    }
91
92    fn extract_duration(&self, content: &str) -> Result<u64, ParseError> {
93        // Simplified duration extraction
94        // Look for numeric patterns
95        let duration_str = content
96            .split_whitespace()
97            .find(|s| s.chars().all(|c| c.is_ascii_digit()))
98            .ok_or_else(|| ParseError::InvalidSyntax {
99                details: "Could not find timeout duration".to_string(),
100            })?;
101
102        duration_str.parse().map_err(|_| ParseError::InvalidSyntax {
103            details: "Invalid timeout duration format".to_string(),
104        })
105    }
106
107    fn extract_roles(
108        &self,
109        _content: &str,
110        context: &ParseContext,
111    ) -> Result<Vec<Role>, ParseError> {
112        // Simplified role extraction; returns all declared roles
113        // (production would properly parse role references from content)
114        Ok(context.declared_roles.to_vec())
115    }
116}
117
/// Protocol extension node representing a `timeout` block.
///
/// Equality compares duration, participating roles and body representation,
/// which is convenient when asserting on parser output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutProtocol {
    /// Timeout applied to the guarded body.
    pub duration: Duration,
    /// Names of the participating roles. Plain strings are stored instead of
    /// `Role` values.
    pub role_names: Vec<String>,
    /// Simplified textual stand-in for the protocol body. Storing the full
    /// Protocol AST would require fixing Send + Sync issues, so this example
    /// keeps a string representation.
    pub body_repr: String,
}
127
128impl ProtocolExtension for TimeoutProtocol {
129    fn type_name(&self) -> &'static str {
130        "TimeoutProtocol"
131    }
132
133    fn mentions_role(&self, role: &Role) -> bool {
134        self.role_names
135            .iter()
136            .any(|name| name == &role.name().to_string())
137    }
138
139    fn validate(&self, all_roles: &[Role]) -> Result<(), ExtensionValidationError> {
140        // Validate that all mentioned roles are declared
141        for role_name in &self.role_names {
142            if !all_roles.iter().any(|r| &r.name().to_string() == role_name) {
143                return Err(ExtensionValidationError::UndeclaredRole {
144                    role: role_name.clone(),
145                });
146            }
147        }
148
149        // Validate duration is reasonable
150        if self.duration.is_zero() {
151            return Err(ExtensionValidationError::InvalidStructure {
152                reason: "Timeout duration cannot be zero".to_string(),
153            });
154        }
155
156        if self.duration > Duration::from_secs(3600) {
157            return Err(ExtensionValidationError::InvalidStructure {
158                reason: "Timeout duration too long (max 1 hour)".to_string(),
159            });
160        }
161
162        Ok(())
163    }
164
165    fn project(
166        &self,
167        role: &Role,
168        _context: &ProjectionContext,
169    ) -> Result<LocalType, ProjectionError> {
170        if self
171            .role_names
172            .iter()
173            .any(|name| name == &role.name().to_string())
174        {
175            // This role participates in the timeout; body defaults to End
176            // (full implementation would project the protocol body recursively)
177            Ok(LocalType::Timeout {
178                duration: self.duration,
179                body: Box::new(LocalType::End),
180                on_timeout: Box::new(LocalType::End),
181                on_cancel: None,
182            })
183        } else {
184            // This role doesn't participate in timeout, return End
185            Ok(LocalType::End)
186        }
187    }
188
189    fn generate_code(&self, _context: &CodegenContext) -> proc_macro2::TokenStream {
190        let duration_ms = u64::try_from(self.duration.as_millis()).unwrap_or(u64::MAX);
191        let _role_names = &self.role_names;
192
193        quote::quote! {
194            // Generate timeout wrapper code
195            .with_timeout(
196                Duration::from_millis(#duration_ms),
197                // Timeout applies to these roles: #(#role_names),*
198            )
199        }
200    }
201
202    fn as_any(&self) -> &dyn Any {
203        self
204    }
205
206    fn as_any_mut(&mut self) -> &mut dyn Any {
207        self
208    }
209
210    fn type_id(&self) -> TypeId {
211        TypeId::of::<Self>()
212    }
213
214    fn clone_box(&self) -> Box<dyn ProtocolExtension> {
215        Box::new(self.clone())
216    }
217}
218
219/// Convenience function to register the timeout extension.
220///
221/// # Errors
222///
223/// Returns an error if there's a priority conflict with an existing extension.
224pub fn register_timeout_extension(
225    registry: &mut ExtensionRegistry,
226) -> Result<(), crate::extensions::ParseError> {
227    registry.register_grammar(TimeoutGrammarExtension)?;
228    registry.register_parser(TimeoutStatementParser, "timeout".to_string());
229    Ok(())
230}
231
232/// Extend LocalType to support timeout
233impl LocalType {
234    pub fn timeout(duration: Duration, body: LocalType) -> Self {
235        Self::Timeout {
236            duration,
237            body: Box::new(body),
238            on_timeout: Box::new(LocalType::End),
239            on_cancel: None,
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::{Ident, Span};

    /// Creates a `Role` with the given name, panicking if it is rejected.
    fn make_role(name: &str) -> Role {
        Role::new(Ident::new(name, Span::call_site())).unwrap()
    }

    /// Creates a `TimeoutProtocol` over `role_names` with an `End` body.
    fn make_timeout(duration: Duration, role_names: Vec<String>) -> TimeoutProtocol {
        TimeoutProtocol {
            duration,
            role_names,
            body_repr: "End".to_string(),
        }
    }

    #[test]
    fn test_timeout_grammar_extension() {
        let ext = TimeoutGrammarExtension;
        let rules = ext.statement_rules();
        assert_eq!(ext.extension_id(), "timeout");
        assert!(rules.contains(&"timeout_ext_stmt"));
        assert!(ext.grammar_rules().contains("timeout_ext_stmt"));
    }

    #[test]
    fn test_timeout_statement_parser() {
        let parser = TimeoutStatementParser;
        assert!(parser.can_parse("timeout_ext_stmt"));
        assert!(!parser.can_parse("unknown_stmt"));
    }

    #[test]
    fn test_timeout_protocol() {
        let timeout_protocol =
            make_timeout(Duration::from_millis(5000), vec!["Alice".to_string()]);
        assert_eq!(timeout_protocol.type_name(), "TimeoutProtocol");

        let alice = make_role("Alice");
        let bob = make_role("Bob");
        assert!(timeout_protocol.mentions_role(&alice));
        assert!(!timeout_protocol.mentions_role(&bob));
    }

    #[test]
    fn test_timeout_validation() {
        let roles = vec![make_role("Alice")];
        let names = || roles.iter().map(|r| r.name().to_string()).collect::<Vec<_>>();

        let valid_timeout = make_timeout(Duration::from_millis(5000), names());
        assert!(valid_timeout.validate(&roles).is_ok());

        // A zero duration must be rejected.
        let invalid_timeout = make_timeout(Duration::ZERO, names());
        assert!(invalid_timeout.validate(&roles).is_err());
    }

    #[test]
    fn test_extension_registration() {
        let mut registry = ExtensionRegistry::new();
        register_timeout_extension(&mut registry).expect("extension should register");
        assert!(registry.can_handle("timeout_ext_stmt"));
    }
}