Skip to main content

aj_core/
aj.rs

1use dashmap::DashMap;
2use kameo::actor::ActorRef;
3use kameo::message::{Context, Message};
4use kameo::Actor;
5use lazy_static::lazy_static;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use std::any::{Any, TypeId};
9use std::sync::{Arc, RwLock};
10
11use crate::job::Job;
12use crate::mem::InMemory;
13use crate::queue::{cancel_job, enqueue_job, WorkQueue};
14use crate::types::Backend;
15use crate::{
16    get_job, retry_job, update_work_queue_config, BackgroundJob, EnqueueConfig, Error, Executable,
17    JobContext, JobPlugin, JobPluginWrapper, PluginCenter, WorkQueueConfig,
18};
19
20lazy_static! {
21    static ref QUEUE_REGISTRY: Registry = Registry::default();
22}
23
24lazy_static! {
25    static ref AJ_BACKEND: Arc<RwLock<Option<Arc<dyn Backend + Send + Sync + 'static>>>> =
26        Arc::new(RwLock::new(None));
27}
28
29lazy_static! {
30    static ref AJ_ADDR: Arc<RwLock<Option<ActorRef<AJ>>>> = Arc::new(RwLock::new(None));
31}
32
33#[derive(Debug, Default)]
34pub struct Registry {
35    registry: DashMap<TypeId, Box<dyn Any + Send + Sync>>,
36    registry_by_name: DashMap<String, Box<dyn Any + Send + Sync>>,
37}
38
39fn get_backend() -> Option<Arc<dyn Backend + Send + Sync + 'static>> {
40    if let Ok(backend) = AJ_BACKEND.try_read() {
41        backend.clone()
42    } else {
43        None
44    }
45}
46
47pub fn get_work_queue_address<M>() -> Option<ActorRef<WorkQueue<M>>>
48where
49    M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
50{
51    let type_id = TypeId::of::<M>();
52    if let Some(queue_addr) = QUEUE_REGISTRY.registry.get(&type_id) {
53        if let Some(addr) = queue_addr.downcast_ref::<ActorRef<WorkQueue<M>>>() {
54            return Some(addr.clone());
55        }
56    }
57
58    None
59}
60
61fn register_work_queue<M>(queue_name: &str) -> ActorRef<WorkQueue<M>>
62where
63    M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
64{
65    let type_id = TypeId::of::<M>();
66    let registry = &QUEUE_REGISTRY;
67
68    // Check if already registered
69    if let Some(queue_addr) = registry.registry.get(&type_id) {
70        if let Some(addr) = queue_addr.downcast_ref::<ActorRef<WorkQueue<M>>>() {
71            return addr.clone();
72        }
73    }
74
75    let backend = get_backend().expect("AJ is not started, please start it via AJ::start");
76
77    // Start Queue actor
78    let queue_ref = WorkQueue::<M>::start_with_name(queue_name.into(), backend);
79    registry
80        .registry
81        .insert(type_id, Box::new(queue_ref.clone()));
82    registry
83        .registry_by_name
84        .insert(queue_name.into(), Box::new(queue_ref.clone()));
85
86    queue_ref
87}
88
89pub fn get_aj_address() -> Option<ActorRef<AJ>> {
90    if let Ok(addr) = AJ_ADDR.try_read() {
91        addr.clone()
92    } else {
93        None
94    }
95}
96
97#[derive(Actor)]
98pub struct AJ {}
99
100impl AJ {
101    /// Start AJ with a custom backend
102    pub fn start(backend: impl Backend + Send + Sync + 'static) -> ActorRef<Self> {
103        if let Some(aj_addr) = get_aj_address() {
104            warn!("AJ is running. Return current AJ");
105            return aj_addr;
106        }
107
108        // Store backend globally
109        if let Ok(ref mut backend_ref) = AJ_BACKEND.try_write() {
110            **backend_ref = Some(Arc::new(backend));
111        }
112
113        let actor = AJ {};
114        let actor_ref = kameo::spawn(actor);
115
116        if let Ok(ref mut aj_addr) = AJ_ADDR.try_write() {
117            **aj_addr = Some(actor_ref.clone());
118        }
119
120        actor_ref
121    }
122
123    /// Quick start AJ with in-memory backend
124    pub fn quick_start() -> ActorRef<Self> {
125        Self::start(InMemory::default())
126    }
127
128    pub async fn enqueue_job<M>(
129        job: Job<M>,
130        config: EnqueueConfig,
131        queue_name: &str,
132    ) -> Result<(), Error>
133    where
134        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
135    {
136        let actor_ref = if let Some(actor_ref) = get_work_queue_address() {
137            actor_ref
138        } else {
139            info!("Not found WorkQueue for {}, creating...", queue_name);
140            register_work_queue::<M>(queue_name)
141        };
142        enqueue_job(actor_ref, job, config).await
143    }
144
145    pub async fn cancel_job<M>(job_id: String) -> Result<(), Error>
146    where
147        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
148    {
149        let actor_ref: Option<ActorRef<WorkQueue<M>>> = get_work_queue_address();
150        if let Some(queue_ref) = actor_ref {
151            cancel_job(queue_ref, job_id).await
152        } else {
153            Err(Error::NoQueueRegister)
154        }
155    }
156
157    pub async fn get_job<M>(job_id: &str) -> Option<Job<M>>
158    where
159        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
160    {
161        let actor_ref: Option<ActorRef<WorkQueue<M>>> = get_work_queue_address();
162        if let Some(queue_ref) = actor_ref {
163            get_job(queue_ref, job_id).await
164        } else {
165            None
166        }
167    }
168
169    pub async fn update_job<M>(
170        job_id: &str,
171        data: M,
172        context: Option<JobContext>,
173    ) -> Result<(), Error>
174    where
175        M: Executable
176            + BackgroundJob
177            + Send
178            + Sync
179            + Clone
180            + Serialize
181            + DeserializeOwned
182            + 'static,
183    {
184        let job = Self::get_job::<M>(job_id).await;
185        if let Some(mut job) = job {
186            job.data = data;
187            if let Some(context) = context {
188                job.context = context;
189            }
190            Self::add_job(job, M::queue_name()).await?;
191        } else {
192            warn!("Cannot update non existing job {job_id}");
193        }
194
195        Ok(())
196    }
197
198    pub async fn retry_job<M>(job_id: &str) -> Result<bool, Error>
199    where
200        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
201    {
202        let actor_ref: Option<ActorRef<WorkQueue<M>>> = get_work_queue_address();
203        if let Some(queue_ref) = actor_ref {
204            retry_job(queue_ref, job_id).await
205        } else {
206            Err(Error::NoQueueRegister)
207        }
208    }
209
210    pub async fn add_job<M>(job: Job<M>, queue_name: &str) -> Result<String, Error>
211    where
212        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
213    {
214        let job_id = job.id().to_string();
215        let config = EnqueueConfig::new_re_run();
216        Self::enqueue_job(job, config, queue_name).await?;
217        Ok(job_id)
218    }
219
220    pub async fn update_work_queue<M>(config: WorkQueueConfig) -> Result<(), Error>
221    where
222        M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
223    {
224        let actor_ref: Option<ActorRef<WorkQueue<M>>> = get_work_queue_address();
225        if let Some(queue_ref) = actor_ref {
226            update_work_queue_config(queue_ref, config).await
227        } else {
228            Err(Error::NoQueueRegister)
229        }
230    }
231
232    pub async fn register_plugin(
233        plugin: impl JobPlugin + Send + Sync + 'static,
234    ) -> Result<(), Error> {
235        let wrapper = JobPluginWrapper::new(plugin, vec![]);
236        PluginCenter::register(wrapper).await
237    }
238}
239
240// Message: JustRunJob (fire and forget style)
241pub struct JustRunJob<M>
242where
243    M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
244{
245    pub job: Job<M>,
246    pub queue_name: String,
247}
248
249impl<M> Message<JustRunJob<M>> for AJ
250where
251    M: Executable + Send + Sync + Clone + Serialize + DeserializeOwned + 'static,
252{
253    type Reply = ();
254
255    async fn handle(
256        &mut self,
257        msg: JustRunJob<M>,
258        _ctx: Context<'_, Self, Self::Reply>,
259    ) -> Self::Reply {
260        if let Err(reason) = AJ::add_job(msg.job, &msg.queue_name).await {
261            error!("Cannot start job {reason:?}");
262        }
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::{get_aj_address, AJ};

    /// After `AJ::quick_start` runs inside a tokio runtime, the actor address
    /// must be retrievable through `get_aj_address`.
    #[tokio::test]
    async fn test_start_aj_under_tokio_runtime() {
        let _actor_ref = AJ::quick_start();

        assert!(get_aj_address().is_some());
    }
}