1use crate::error::ClusterError;
2use crate::sharding::Sharding;
3use crate::snowflake::SnowflakeGenerator;
4use crate::types::{EntityAddress, EntityId, EntityType, RunnerAddress};
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio_stream::Stream;
11
/// Runtime context passed to [`Entity::spawn`] when a single entity instance is created.
///
/// The shared services (`snowflake` and every optional backend) are held behind `Arc`s.
/// The optional backends may be `None`; the unit tests in this file construct a context
/// with all of them unset.
#[derive(Clone)]
pub struct EntityContext {
    /// Address of the entity instance: its shard id, entity type and entity id.
    pub address: EntityAddress,
    /// Address of a runner (host + port in the tests).
    /// NOTE(review): presumably the runner hosting this entity — confirm.
    pub runner_address: RunnerAddress,
    /// Shared Snowflake ID generator from `crate::snowflake`.
    pub snowflake: Arc<SnowflakeGenerator>,
    /// Cancellation token handed to the entity.
    /// NOTE(review): presumably cancelled when the entity is stopped — confirm with the runtime.
    pub cancellation: tokio_util::sync::CancellationToken,
    /// Optional workflow storage backend; `None` when not configured.
    pub state_storage: Option<Arc<dyn crate::__internal::WorkflowStorage>>,
    /// Optional workflow engine; `None` when not configured.
    pub workflow_engine: Option<Arc<dyn crate::__internal::WorkflowEngine>>,
    /// Optional handle to the sharding layer; `None` when not configured.
    pub sharding: Option<Arc<dyn Sharding>>,
    /// Optional message storage backend; `None` when not configured.
    pub message_storage: Option<Arc<dyn crate::message_storage::MessageStorage>>,
}
48
49#[async_trait]
55pub trait Entity: Send + Sync + 'static {
56 fn entity_type(&self) -> EntityType;
58
59 fn shard_group(&self) -> &str {
61 "default"
62 }
63
64 fn shard_group_for(&self, _entity_id: &EntityId) -> &str {
66 self.shard_group()
67 }
68
69 fn max_idle_time(&self) -> Option<Duration> {
71 None
72 }
73
74 fn mailbox_capacity(&self) -> Option<usize> {
76 None
77 }
78
79 fn concurrency(&self) -> Option<usize> {
89 None
90 }
91
92 async fn spawn(&self, ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError>;
95}
96
97#[async_trait]
103pub trait EntityHandler: Send + Sync {
104 async fn handle_request(
106 &self,
107 tag: &str,
108 payload: &[u8],
109 headers: &HashMap<String, String>,
110 ) -> Result<Vec<u8>, ClusterError>;
111
112 async fn handle_stream(
116 &self,
117 tag: &str,
118 payload: &[u8],
119 headers: &HashMap<String, String>,
120 ) -> Result<Pin<Box<dyn Stream<Item = Result<Vec<u8>, ClusterError>> + Send>>, ClusterError>
121 {
122 let result = self.handle_request(tag, payload, headers).await?;
123 Ok(Box::pin(tokio_stream::once(Ok(result))))
124 }
125
126 async fn on_idle(&self) -> bool {
129 false
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EntityId, EntityType, ShardId};

    /// Minimal entity relying on every default hook; spawns a [`CounterHandler`].
    struct CounterEntity;

    #[async_trait]
    impl Entity for CounterEntity {
        fn entity_type(&self) -> EntityType {
            EntityType::new("Counter")
        }

        async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
            Ok(Box::new(CounterHandler))
        }
    }

    /// Handler that answers `"increment"` with the MessagePack encoding of `1i32` and
    /// rejects every other tag with `MalformedMessage`.
    struct CounterHandler;

    #[async_trait]
    impl EntityHandler for CounterHandler {
        async fn handle_request(
            &self,
            tag: &str,
            _payload: &[u8],
            _headers: &HashMap<String, String>,
        ) -> Result<Vec<u8>, ClusterError> {
            match tag {
                "increment" => Ok(rmp_serde::to_vec(&1i32).unwrap()),
                _ => Err(ClusterError::MalformedMessage {
                    reason: format!("unknown tag: {tag}"),
                    source: None,
                }),
            }
        }
    }

    /// Shared fixture for the spawn tests: entity `Counter/c-1` on shard `("default", 0)`,
    /// runner `127.0.0.1:9000`, and every optional service left unset.
    fn counter_ctx() -> EntityContext {
        EntityContext {
            address: EntityAddress {
                shard_id: ShardId::new("default", 0),
                entity_type: EntityType::new("Counter"),
                entity_id: EntityId::new("c-1"),
            },
            runner_address: RunnerAddress::new("127.0.0.1", 9000),
            snowflake: Arc::new(SnowflakeGenerator::new()),
            cancellation: tokio_util::sync::CancellationToken::new(),
            state_storage: None,
            workflow_engine: None,
            sharding: None,
            message_storage: None,
        }
    }

    #[test]
    fn default_shard_group_returns_default() {
        let entity = CounterEntity;
        assert_eq!(entity.shard_group(), "default");
    }

    #[test]
    fn default_shard_group_for_delegates() {
        let entity = CounterEntity;
        let id = EntityId::new("some-id");
        assert_eq!(entity.shard_group_for(&id), "default");
    }

    #[test]
    fn default_max_idle_time_is_none() {
        let entity = CounterEntity;
        assert!(entity.max_idle_time().is_none());
    }

    #[test]
    fn default_mailbox_capacity_is_none() {
        let entity = CounterEntity;
        assert!(entity.mailbox_capacity().is_none());
    }

    #[test]
    fn default_concurrency_is_none() {
        let entity = CounterEntity;
        assert!(entity.concurrency().is_none());
    }

    /// Entity overriding every optional hook, used to check that overrides take effect.
    struct CustomEntity;

    #[async_trait]
    impl Entity for CustomEntity {
        fn entity_type(&self) -> EntityType {
            EntityType::new("Custom")
        }

        fn shard_group(&self) -> &str {
            "premium"
        }

        fn max_idle_time(&self) -> Option<Duration> {
            Some(Duration::from_secs(120))
        }

        fn mailbox_capacity(&self) -> Option<usize> {
            Some(50)
        }

        fn concurrency(&self) -> Option<usize> {
            Some(4)
        }

        async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
            Ok(Box::new(CounterHandler))
        }
    }

    #[test]
    fn custom_shard_group() {
        let entity = CustomEntity;
        assert_eq!(entity.shard_group(), "premium");
        // The default `shard_group_for` must pick up the overridden `shard_group`.
        assert_eq!(entity.shard_group_for(&EntityId::new("x")), "premium");
    }

    #[test]
    fn custom_max_idle_time() {
        let entity = CustomEntity;
        assert_eq!(entity.max_idle_time(), Some(Duration::from_secs(120)));
    }

    #[test]
    fn custom_mailbox_capacity() {
        let entity = CustomEntity;
        assert_eq!(entity.mailbox_capacity(), Some(50));
    }

    #[test]
    fn custom_concurrency() {
        let entity = CustomEntity;
        assert_eq!(entity.concurrency(), Some(4));
    }

    #[tokio::test]
    async fn spawn_and_handle_request() {
        let entity = CounterEntity;
        let handler = entity.spawn(counter_ctx()).await.unwrap();
        let result = handler
            .handle_request("increment", &[], &HashMap::new())
            .await
            .unwrap();
        let value: i32 = rmp_serde::from_slice(&result).unwrap();
        assert_eq!(value, 1);
    }

    #[tokio::test]
    async fn handle_unknown_tag_returns_error() {
        let entity = CounterEntity;
        let handler = entity.spawn(counter_ctx()).await.unwrap();
        let err = handler
            .handle_request("unknown", &[], &HashMap::new())
            .await
            .unwrap_err();
        assert!(matches!(err, ClusterError::MalformedMessage { .. }));
    }

    #[tokio::test]
    async fn default_handle_stream_wraps_request() {
        use tokio_stream::StreamExt;

        let entity = CounterEntity;
        let handler = entity.spawn(counter_ctx()).await.unwrap();
        let mut stream = handler
            .handle_stream("increment", &[], &HashMap::new())
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        let value: i32 = rmp_serde::from_slice(&first).unwrap();
        assert_eq!(value, 1);

        // The default wrapper yields exactly one item.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn default_handle_stream_propagates_request_error() {
        let entity = CounterEntity;
        let handler = entity.spawn(counter_ctx()).await.unwrap();
        let result = handler
            .handle_stream("unknown", &[], &HashMap::new())
            .await;
        // The default `handle_stream` uses `?`, so a request error surfaces as the outer
        // `Err` rather than as an item inside the stream. `matches!` is used because the
        // `Ok` side (a boxed stream) is not `Debug`.
        assert!(matches!(result, Err(ClusterError::MalformedMessage { .. })));
    }

    #[tokio::test]
    async fn default_on_idle_returns_false() {
        let handler = CounterHandler;
        assert!(!handler.on_idle().await);
    }
}