Skip to main content

cruster/
entity.rs

1use crate::error::ClusterError;
2use crate::sharding::Sharding;
3use crate::snowflake::SnowflakeGenerator;
4use crate::types::{EntityAddress, EntityId, EntityType, RunnerAddress};
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_stream::Stream;
11
12/// Context provided to entity instances when they are spawned.
13#[derive(Clone)]
14pub struct EntityContext {
15    /// The address of this entity instance.
16    pub address: EntityAddress,
17    /// The address of the runner hosting this entity.
18    pub runner_address: RunnerAddress,
19    /// Shared snowflake ID generator.
20    pub snowflake: Arc<SnowflakeGenerator>,
21    /// Cancellation token for this entity's lifetime.
22    pub cancellation: tokio_util::sync::CancellationToken,
23    /// Optional key-value storage for persisted entity state.
24    ///
25    /// When present, entity macros with `#[state(Type)]` will load
26    /// state from this storage on spawn and save after every `#[activity]` handler
27    /// call. The storage key is `"entity/{entity_type}/{entity_id}/state"`.
28    pub state_storage: Option<Arc<dyn crate::__internal::WorkflowStorage>>,
29    /// Optional workflow engine for durable context support.
30    ///
31    /// When present, entity methods with `&DurableContext` parameters can use
32    /// durable sleep, await_deferred, resolve_deferred, and on_interrupt operations.
33    /// The macro generates code to construct a `DurableContext` from this engine.
34    pub workflow_engine: Option<Arc<dyn crate::__internal::WorkflowEngine>>,
35    /// Optional sharding interface for inter-entity communication.
36    ///
37    /// When present, entities can create clients to send messages to other entities
38    /// or to themselves, including scheduled messages via `notify_at`.
39    pub sharding: Option<Arc<dyn Sharding>>,
40    /// Optional message storage for activity journaling.
41    ///
42    /// When present, activities called from within `#[workflow]` methods are
43    /// journaled: their results are cached in `MessageStorage` so that on
44    /// crash-recovery replay the cached result is returned instead of
45    /// re-executing the activity body.
46    pub message_storage: Option<Arc<dyn crate::message_storage::MessageStorage>>,
47}
48
49/// Defines an entity type with its RPCs and behavior.
50///
51/// Users implement this trait to define an entity. Each entity type has a unique
52/// name and a factory method (`spawn`) that creates handler instances for
53/// individual entity IDs.
54#[async_trait]
55pub trait Entity: Send + Sync + 'static {
56    /// Unique type name for this entity (e.g., "User", "Order").
57    fn entity_type(&self) -> EntityType;
58
59    /// Shard group this entity belongs to. Default: "default".
60    fn shard_group(&self) -> &str {
61        "default"
62    }
63
64    /// Resolve shard group from entity ID. Override for custom routing.
65    fn shard_group_for(&self, _entity_id: &EntityId) -> &str {
66        self.shard_group()
67    }
68
69    /// Maximum idle time before reaping. None = use config default.
70    fn max_idle_time(&self) -> Option<Duration> {
71        None
72    }
73
74    /// Mailbox capacity. None = use config default.
75    fn mailbox_capacity(&self) -> Option<usize> {
76        None
77    }
78
79    /// Maximum number of concurrent requests this entity can handle.
80    /// `None` = use config default (`entity_max_concurrent_requests`).
81    /// `Some(0)` = unbounded concurrency. `Some(1)` = serial (default behavior).
82    /// `Some(n)` = at most `n` concurrent requests.
83    ///
84    /// When concurrency > 1, the handler must be safe for concurrent access
85    /// (which is guaranteed by the `Send + Sync` bound on `EntityHandler`).
86    /// Crash recovery under concurrency > 1 will replay ALL in-flight requests
87    /// against the new handler.
88    fn concurrency(&self) -> Option<usize> {
89        None
90    }
91
92    /// Create a handler instance for the given entity address.
93    /// The returned handler lives for the lifetime of the entity instance.
94    async fn spawn(&self, ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError>;
95}
96
97/// Handles incoming RPCs for a specific entity instance.
98///
99/// Each entity instance has one handler that processes all incoming messages.
100/// The handler is created by `Entity::spawn` and lives until the entity is
101/// reaped (idle timeout) or the runner shuts down.
102#[async_trait]
103pub trait EntityHandler: Send + Sync {
104    /// Handle an incoming request. Returns serialized response bytes.
105    async fn handle_request(
106        &self,
107        tag: &str,
108        payload: &[u8],
109        headers: &HashMap<String, String>,
110    ) -> Result<Vec<u8>, ClusterError>;
111
112    /// Handle a streaming request. Returns a stream of serialized chunks.
113    ///
114    /// Default implementation wraps `handle_request` as a single-item stream.
115    async fn handle_stream(
116        &self,
117        tag: &str,
118        payload: &[u8],
119        headers: &HashMap<String, String>,
120    ) -> Result<Pin<Box<dyn Stream<Item = Result<Vec<u8>, ClusterError>> + Send>>, ClusterError>
121    {
122        let result = self.handle_request(tag, payload, headers).await?;
123        Ok(Box::pin(tokio_stream::once(Ok(result))))
124    }
125
126    /// Called when the entity is about to be reaped (idle timeout).
127    /// Return true to keep alive, false to allow reaping.
128    async fn on_idle(&self) -> bool {
129        false
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EntityId, EntityType, ShardId};

    /// A mock entity for testing that the trait compiles and works.
    ///
    /// It overrides nothing but the required methods, so it also serves to
    /// exercise every default method of `Entity`.
    struct CounterEntity;

    #[async_trait]
    impl Entity for CounterEntity {
        fn entity_type(&self) -> EntityType {
            EntityType::new("Counter")
        }

        async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
            Ok(Box::new(CounterHandler))
        }
    }

    /// Handler that answers the `"increment"` tag with a MessagePack-encoded
    /// `1i32` and rejects every other tag with `MalformedMessage`.
    struct CounterHandler;

    #[async_trait]
    impl EntityHandler for CounterHandler {
        async fn handle_request(
            &self,
            tag: &str,
            _payload: &[u8],
            _headers: &HashMap<String, String>,
        ) -> Result<Vec<u8>, ClusterError> {
            match tag {
                "increment" => Ok(rmp_serde::to_vec(&1i32).unwrap()),
                _ => Err(ClusterError::MalformedMessage {
                    reason: format!("unknown tag: {tag}"),
                    source: None,
                }),
            }
        }
    }

    /// Builds the `EntityContext` shared by the async tests: a `Counter`
    /// entity with ID `c-1` on shard 0 of the `default` group, with every
    /// optional facility left as `None`.
    fn counter_ctx() -> EntityContext {
        EntityContext {
            address: EntityAddress {
                shard_id: ShardId::new("default", 0),
                entity_type: EntityType::new("Counter"),
                entity_id: EntityId::new("c-1"),
            },
            runner_address: RunnerAddress::new("127.0.0.1", 9000),
            snowflake: Arc::new(SnowflakeGenerator::new()),
            cancellation: tokio_util::sync::CancellationToken::new(),
            state_storage: None,
            workflow_engine: None,
            sharding: None,
            message_storage: None,
        }
    }

    /// Spawns a handler through `CounterEntity::spawn` using `counter_ctx()`.
    async fn spawn_counter_handler() -> Box<dyn EntityHandler> {
        CounterEntity
            .spawn(counter_ctx())
            .await
            .expect("CounterEntity::spawn always returns Ok")
    }

    #[test]
    fn default_shard_group_returns_default() {
        let entity = CounterEntity;
        assert_eq!(entity.shard_group(), "default");
    }

    #[test]
    fn default_shard_group_for_delegates() {
        let entity = CounterEntity;
        let id = EntityId::new("some-id");
        assert_eq!(entity.shard_group_for(&id), "default");
    }

    #[test]
    fn default_max_idle_time_is_none() {
        let entity = CounterEntity;
        assert!(entity.max_idle_time().is_none());
    }

    #[test]
    fn default_mailbox_capacity_is_none() {
        let entity = CounterEntity;
        assert!(entity.mailbox_capacity().is_none());
    }

    #[test]
    fn default_concurrency_is_none() {
        let entity = CounterEntity;
        assert!(entity.concurrency().is_none());
    }

    /// Custom entity that overrides defaults.
    struct CustomEntity;

    #[async_trait]
    impl Entity for CustomEntity {
        fn entity_type(&self) -> EntityType {
            EntityType::new("Custom")
        }

        fn shard_group(&self) -> &str {
            "premium"
        }

        fn max_idle_time(&self) -> Option<Duration> {
            Some(Duration::from_secs(120))
        }

        fn mailbox_capacity(&self) -> Option<usize> {
            Some(50)
        }

        async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
            Ok(Box::new(CounterHandler))
        }
    }

    #[test]
    fn custom_shard_group() {
        let entity = CustomEntity;
        assert_eq!(entity.shard_group(), "premium");
        // `shard_group_for` is not overridden, so it must follow the override
        // of `shard_group`.
        assert_eq!(entity.shard_group_for(&EntityId::new("x")), "premium");
    }

    #[test]
    fn custom_max_idle_time() {
        let entity = CustomEntity;
        assert_eq!(entity.max_idle_time(), Some(Duration::from_secs(120)));
    }

    #[test]
    fn custom_mailbox_capacity() {
        let entity = CustomEntity;
        assert_eq!(entity.mailbox_capacity(), Some(50));
    }

    #[tokio::test]
    async fn spawn_and_handle_request() {
        let handler = spawn_counter_handler().await;
        let result = handler
            .handle_request("increment", &[], &HashMap::new())
            .await
            .unwrap();
        let value: i32 = rmp_serde::from_slice(&result).unwrap();
        assert_eq!(value, 1);
    }

    #[tokio::test]
    async fn handle_unknown_tag_returns_error() {
        let handler = spawn_counter_handler().await;
        let err = handler
            .handle_request("unknown", &[], &HashMap::new())
            .await
            .unwrap_err();
        assert!(matches!(err, ClusterError::MalformedMessage { .. }));
    }

    #[tokio::test]
    async fn default_handle_stream_wraps_request() {
        use tokio_stream::StreamExt;

        let handler = spawn_counter_handler().await;
        let mut stream = handler
            .handle_stream("increment", &[], &HashMap::new())
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        let value: i32 = rmp_serde::from_slice(&first).unwrap();
        assert_eq!(value, 1);

        // Stream should be exhausted
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn default_handle_stream_propagates_request_error() {
        let handler = spawn_counter_handler().await;
        // The boxed stream is not `Debug`, so match on the result instead of
        // calling `unwrap_err`.
        let outcome = handler
            .handle_stream("unknown", &[], &HashMap::new())
            .await;
        assert!(matches!(
            outcome,
            Err(ClusterError::MalformedMessage { .. })
        ));
    }

    #[tokio::test]
    async fn default_on_idle_returns_false() {
        let handler = CounterHandler;
        assert!(!handler.on_idle().await);
    }
}