1use std::time::Duration as StdDuration;
8
9use chrono::{Duration, Utc};
10use diesel::ExpressionMethods;
11use diesel::OptionalExtension;
12use diesel::QueryDsl;
13use diesel_async::AsyncPgConnection;
14use diesel_async::RunQueryDsl;
15use uuid::Uuid;
16
17use crate::error::HarvestResult;
18use crate::models::{NewTaskQueueItem, TaskQueueItem};
19use crate::telemetry::TraceContextCarrier;
20use crate::types::{ExecutionId, Priority};
21
/// Kind of work item stored in `harvest_task_queue`.
///
/// The variant is persisted in the `task_type` column using [`TaskType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    /// A workflow task (stored as `"workflow"`).
    Workflow,
    /// An activity task (stored as `"activity"`).
    Activity,
}
34
/// How far into the past "run immediately" timestamps are backdated.
///
/// `scheduled_at` values are computed from the application clock (`Utc::now()`).
/// `claim_task`, however, compares them against the database clock
/// (`scheduled_at <= NOW()`). Backdating by a few seconds presumably keeps
/// freshly enqueued or woken tasks claimable when the application host's clock
/// runs slightly ahead of the database server's.
const IMMEDIATE_SCHEDULE_SKEW_ALLOWANCE: Duration = Duration::seconds(5);
36
37impl TaskType {
38 #[must_use]
42 pub const fn as_str(self) -> &'static str {
43 match self {
44 Self::Workflow => "workflow",
45 Self::Activity => "activity",
46 }
47 }
48}
49
50impl std::fmt::Display for TaskType {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.write_str(self.as_str())
53 }
54}
55
/// Parameters for [`enqueue`], describing a single task row to insert.
///
/// Build with [`EnqueueParams::new`], then adjust fields directly or through the
/// `with_*` helpers.
#[derive(Debug, Clone)]
pub struct EnqueueParams {
    /// Queue the task is inserted into; `claim_task` only returns tasks from the
    /// queues the worker asks for.
    pub queue_name: String,
    /// Whether this is a workflow or an activity task.
    pub task_type: TaskType,
    /// Owning workflow execution, if any.
    pub workflow_exec_id: Option<Uuid>,
    /// Activity name; when claiming, it is matched against the circuit-breaker
    /// and ineligible activity lists.
    pub activity_name: Option<String>,
    /// Activity identifier, stored as-is.
    pub activity_id: Option<Uuid>,
    /// JSON input payload for the task.
    pub input: serde_json::Value,
    /// Claim priority; higher values are claimed first.
    pub priority: i32,
    /// Maximum number of attempts, stored as-is (defaults to `3` in `new`).
    pub max_attempts: i32,
    /// Earliest time at which the task may be claimed.
    pub scheduled_at: chrono::DateTime<Utc>,
    /// Heartbeat timeout, stored as-is.
    pub heartbeat_timeout: Option<Duration>,
    /// Start-to-close timeout, stored as-is.
    pub start_to_close: Option<Duration>,
    /// Schedule-to-start timeout, stored as-is.
    pub schedule_to_start: Option<Duration>,
    /// Retry policy as JSON, stored as-is.
    pub retry_policy: Option<serde_json::Value>,
    /// Worker the task should be pinned to. Only applied when
    /// `sticky_timeout` is also set.
    pub sticky_worker_id: Option<String>,
    /// How long the sticky pin lasts after enqueue. Only applied when
    /// `sticky_worker_id` is also set, and it must fit in a `chrono::Duration`.
    pub sticky_timeout: Option<StdDuration>,
    /// Trace context stored with the task (converted with
    /// `TraceContextCarrier::to_json` on insert).
    pub trace_context: Option<TraceContextCarrier>,
    /// Concurrency group. Running tasks are counted per
    /// `(concurrency_key, task_type)`.
    pub concurrency_key: Option<String>,
    /// Maximum number of concurrently running tasks in the concurrency group.
    /// Values above `i32::MAX` are saturated on insert.
    pub max_concurrent: Option<u32>,
    /// Build id that a claiming worker must run or be declared compatible with.
    pub required_build_id: Option<String>,
    /// Key of the token bucket in `harvest_rate_limit_buckets` that gates claims.
    pub rate_limit_key: Option<String>,
    /// Deadline at and after which the task is no longer claimable.
    pub schedule_to_close_at: Option<chrono::DateTime<Utc>>,
    /// JSON array of `Exact` / `In` label requirements, matched against the
    /// claiming worker's labels.
    pub required_capabilities: Option<serde_json::Value>,
    /// Context headers as JSON, stored as-is.
    pub context_headers: Option<serde_json::Value>,
}
115
116impl EnqueueParams {
117 #[must_use]
119 pub fn new(
120 queue_name: impl Into<String>,
121 task_type: TaskType,
122 input: serde_json::Value,
123 ) -> Self {
124 Self {
125 queue_name: queue_name.into(),
126 task_type,
127 workflow_exec_id: None,
128 activity_name: None,
129 activity_id: None,
130 input,
131 priority: 0,
132 max_attempts: 3,
133 scheduled_at: Utc::now() - IMMEDIATE_SCHEDULE_SKEW_ALLOWANCE,
136 heartbeat_timeout: None,
137 start_to_close: None,
138 schedule_to_start: None,
139 retry_policy: None,
140 sticky_worker_id: None,
141 sticky_timeout: None,
142 trace_context: None,
143 concurrency_key: None,
144 max_concurrent: None,
145 required_build_id: None,
146 rate_limit_key: None,
147 schedule_to_close_at: None,
148 required_capabilities: None,
149 context_headers: None,
150 }
151 }
152
153 #[must_use]
159 pub const fn with_priority(mut self, priority: Priority) -> Self {
160 self.priority = priority.as_i32();
161 self
162 }
163
164 #[must_use]
168 pub fn with_sticky(mut self, worker_id: impl Into<String>, timeout: StdDuration) -> Self {
169 self.sticky_worker_id = Some(worker_id.into());
170 self.sticky_timeout = Some(timeout);
171 self
172 }
173
174 #[must_use]
177 pub fn with_trace_context(mut self, carrier: TraceContextCarrier) -> Self {
178 self.trace_context = Some(carrier);
179 self
180 }
181}
182
183pub async fn enqueue(conn: &mut AsyncPgConnection, params: &EnqueueParams) -> HarvestResult<Uuid> {
193 use crate::schema::harvest_task_queue;
194
195 let task_id = Uuid::new_v4();
196
197 let sticky = match (params.sticky_worker_id.as_deref(), params.sticky_timeout) {
202 (Some(worker), Some(timeout)) => {
203 let chrono_timeout = chrono::Duration::from_std(timeout).map_err(|_| {
204 crate::error::HarvestError::Config(
205 "sticky_timeout exceeds chrono duration range".to_string(),
206 )
207 })?;
208 Some((worker, chrono_timeout))
209 }
210 _ => None,
211 };
212
213 let concurrency_cap = params
214 .max_concurrent
215 .map(|n| i32::try_from(n).unwrap_or(i32::MAX));
216
217 let row = NewTaskQueueItem {
218 id: task_id,
219 queue_name: ¶ms.queue_name,
220 task_type: params.task_type.as_str(),
221 workflow_exec_id: params.workflow_exec_id,
222 activity_name: params.activity_name.as_deref(),
223 activity_id: params.activity_id,
224 input: params.input.clone(),
225 priority: params.priority,
226 max_attempts: params.max_attempts,
227 scheduled_at: params.scheduled_at,
228 heartbeat_timeout: params.heartbeat_timeout,
229 start_to_close: params.start_to_close,
230 schedule_to_start: params.schedule_to_start,
231 retry_policy: params.retry_policy.clone(),
232 heartbeat_details: None,
233 sticky_worker_id: None,
234 sticky_until: None,
235 sticky_timeout: None,
236 trace_context: params
237 .trace_context
238 .as_ref()
239 .and_then(TraceContextCarrier::to_json),
240 concurrency_key: params.concurrency_key.as_deref(),
241 concurrency_cap,
242 required_build_id: params.required_build_id.as_deref(),
243 rate_limit_key: params.rate_limit_key.as_deref(),
244 schedule_to_close_at: params.schedule_to_close_at,
245 required_capabilities: params.required_capabilities.clone(),
246 context_headers: params.context_headers.clone(),
247 };
248
249 diesel::insert_into(harvest_task_queue::table)
250 .values(&row)
251 .execute(conn)
252 .await
253 .map_err(crate::error::database_error)?;
254
255 if let Some((worker_id, timeout)) = sticky {
256 diesel::sql_query(
257 "UPDATE harvest_task_queue \
258 SET sticky_worker_id = $2, \
259 sticky_until = NOW() + $3, \
260 sticky_timeout = $3 \
261 WHERE id = $1",
262 )
263 .bind::<diesel::sql_types::Uuid, _>(task_id)
264 .bind::<diesel::sql_types::Text, _>(worker_id)
265 .bind::<diesel::sql_types::Interval, _>(timeout)
266 .execute(conn)
267 .await
268 .map_err(crate::error::database_error)?;
269 }
270
271 crate::notify::notify_task_enqueued(conn, ¶ms.queue_name, task_id).await?;
272
273 Ok(task_id)
274}
275
276#[allow(clippy::too_many_lines)]
301pub async fn claim_task(
302 conn: &mut AsyncPgConnection,
303 queues: &[String],
304 worker_id: &str,
305 worker_build_id: &str,
306 priority_aging_secs: Option<u32>,
307 circuit_breaker_activities: &[String],
308 ineligible_activities: &[String],
309) -> HarvestResult<Option<TaskQueueItem>> {
310 let aging_secs_i64: Option<i64> = priority_aging_secs.map(i64::from);
366
367 let result: Vec<TaskQueueItem> = diesel::sql_query(
368 "WITH worker_info AS ( \
369 SELECT COALESCE((SELECT labels FROM harvest_workers WHERE worker_id = $1), '{}'::jsonb) AS labels \
370 ), \
371 candidate AS ( \
372 SELECT id, task_type, concurrency_key, concurrency_cap, rate_limit_key, activity_name \
373 FROM harvest_task_queue \
374 CROSS JOIN worker_info \
375 WHERE queue_name = ANY($2) \
376 AND state = 'PENDING' \
377 AND scheduled_at <= NOW() \
378 AND ( \
379 schedule_to_close_at IS NULL \
380 OR schedule_to_close_at > NOW() \
381 ) \
382 AND ( \
383 sticky_worker_id IS NULL \
384 OR sticky_worker_id = $1 \
385 OR sticky_until IS NULL \
386 OR sticky_until <= NOW() \
387 ) \
388 AND ( \
389 concurrency_key IS NULL \
390 OR concurrency_cap IS NULL \
391 OR ( \
392 SELECT COUNT(*) FROM harvest_task_queue inner_q \
393 WHERE inner_q.concurrency_key = harvest_task_queue.concurrency_key \
394 AND inner_q.task_type = harvest_task_queue.task_type \
395 AND inner_q.state = 'RUNNING' \
396 AND inner_q.worker_id IS NOT NULL \
397 ) < harvest_task_queue.concurrency_cap \
398 ) \
399 AND ( \
400 required_build_id IS NULL \
401 OR $3 = '' \
402 OR required_build_id = $3 \
403 OR EXISTS ( \
404 SELECT 1 FROM harvest_build_compat \
405 WHERE build_id = $3 \
406 AND compatible_with = harvest_task_queue.required_build_id \
407 ) \
408 ) \
409 AND ( \
410 task_type <> 'workflow' \
411 OR workflow_exec_id IS NULL \
412 OR NOT EXISTS ( \
413 SELECT 1 FROM harvest_workflow_executions e \
414 WHERE e.id = harvest_task_queue.workflow_exec_id \
415 AND e.state = 'PAUSED' \
416 ) \
417 ) \
418 AND ( \
419 task_type != 'activity' \
420 OR activity_name IS NULL \
421 OR required_capabilities IS NOT NULL \
422 OR NOT (activity_name = ANY($6)) \
423 ) \
424 AND ( \
425 required_capabilities IS NULL \
426 OR NOT EXISTS ( \
427 SELECT 1 \
428 FROM jsonb_array_elements(required_capabilities) AS r(value) \
429 WHERE ( \
430 r.value ? 'Exact' AND ( \
431 worker_info.labels->>(r.value->'Exact'->>'key') IS NULL \
432 OR worker_info.labels->>(r.value->'Exact'->>'key') != (r.value->'Exact'->>'value') \
433 ) \
434 ) OR ( \
435 r.value ? 'In' AND ( \
436 worker_info.labels->>(r.value->'In'->>'key') IS NULL \
437 OR NOT ( \
438 (r.value->'In'->'values') @> jsonb_build_array(worker_info.labels->>(r.value->'In'->>'key')) \
439 ) \
440 ) \
441 ) \
442 ) \
443 ) \
444 AND ( \
445 rate_limit_key IS NULL \
446 OR harvest_task_queue.activity_name = ANY($5) \
447 OR EXISTS ( \
448 SELECT 1 FROM harvest_rate_limit_buckets b \
449 WHERE b.key = harvest_task_queue.rate_limit_key \
450 AND LEAST(b.burst, b.tokens + EXTRACT(EPOCH FROM (NOW() - b.last_refilled_at)) * b.refill_rate) >= 1.0 \
451 ) \
452 ) \
453 ORDER BY \
454 CASE \
455 WHEN sticky_worker_id = $1 AND sticky_until > NOW() THEN 1 \
456 ELSE 0 \
457 END DESC, \
458 CASE \
459 WHEN $4::BIGINT IS NOT NULL AND $4::BIGINT > 0 \
460 THEN priority + FLOOR(EXTRACT(EPOCH FROM (NOW() - scheduled_at)) / $4::BIGINT)::INT \
461 ELSE priority \
462 END DESC, \
463 scheduled_at ASC \
464 LIMIT 1 FOR UPDATE SKIP LOCKED \
465 ), \
466 rate_limit_debit AS ( \
467 UPDATE harvest_rate_limit_buckets b \
468 SET tokens = LEAST(b.burst, b.tokens + EXTRACT(EPOCH FROM (NOW() - b.last_refilled_at)) * b.refill_rate) - 1.0, \
469 last_refilled_at = NOW() \
470 FROM candidate \
471 WHERE b.key = candidate.rate_limit_key \
472 AND NOT (candidate.activity_name = ANY($5)) \
473 AND LEAST(b.burst, b.tokens + EXTRACT(EPOCH FROM (NOW() - b.last_refilled_at)) * b.refill_rate) >= 1.0 \
474 RETURNING b.key AS debited_key \
475 ), \
476 claimed AS ( \
477 UPDATE harvest_task_queue \
478 SET state = 'RUNNING', worker_id = $1, started_at = NOW(), attempt = attempt + 1 \
479 FROM candidate \
480 WHERE harvest_task_queue.id = candidate.id \
481 AND ( \
482 candidate.concurrency_key IS NULL \
483 OR ( \
484 pg_try_advisory_xact_lock(hashtext(candidate.concurrency_key)::bigint) \
485 AND ( \
486 candidate.concurrency_cap IS NULL \
487 OR ( \
488 SELECT COUNT(*) FROM harvest_task_queue recheck \
489 WHERE recheck.concurrency_key = candidate.concurrency_key \
490 AND recheck.task_type = candidate.task_type \
491 AND recheck.state = 'RUNNING' \
492 AND recheck.worker_id IS NOT NULL \
493 ) < candidate.concurrency_cap \
494 ) \
495 ) \
496 ) \
497 AND ( \
498 candidate.rate_limit_key IS NULL \
499 OR candidate.activity_name = ANY($5) \
500 OR EXISTS (SELECT 1 FROM rate_limit_debit WHERE debited_key = candidate.rate_limit_key) \
501 ) \
502 RETURNING harvest_task_queue.* \
503 ) \
504 SELECT * FROM claimed",
505 )
506 .bind::<diesel::sql_types::Text, _>(worker_id)
507 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(queues)
508 .bind::<diesel::sql_types::Text, _>(worker_build_id)
509 .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(aging_secs_i64)
510 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(circuit_breaker_activities)
511 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(ineligible_activities)
512 .load(conn)
513 .await
514 .map_err(crate::error::database_error)?;
515
516 Ok(result.into_iter().next())
517}
518
/// Per-queue load snapshot, presumably consumed by an autoscaler.
///
/// NOTE(review): this type is not populated anywhere in the visible code. The
/// first four fields appear to mirror `QueueTaskCounts`; confirm against the
/// producer.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct QueueScalingSignal {
    /// Queue name.
    pub queue: String,
    /// Tasks ready to run now (presumably as in `QueueTaskCounts::backlog`).
    pub backlog: i64,
    /// Tasks currently running (presumably as in `QueueTaskCounts::in_flight`).
    pub in_flight: i64,
    /// Tasks scheduled for later (presumably as in `QueueTaskCounts::scheduled`).
    pub scheduled: i64,
    /// Number of active workers for the queue; how it is computed is not
    /// visible here.
    pub active_workers: i64,
}
538
/// Task counts for one queue, as produced by `queue_task_counts`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct QueueTaskCounts {
    /// Queue name.
    pub queue: String,
    /// `PENDING` tasks whose `scheduled_at` has already passed.
    pub backlog: i64,
    /// Tasks in state `RUNNING`.
    pub in_flight: i64,
    /// `PENDING` tasks scheduled for the future.
    pub scheduled: i64,
}
551
552pub async fn queue_task_counts(
558 conn: &mut AsyncPgConnection,
559) -> HarvestResult<Vec<QueueTaskCounts>> {
560 #[derive(diesel::QueryableByName)]
561 struct Row {
562 #[diesel(sql_type = diesel::sql_types::Text)]
563 queue: String,
564 #[diesel(sql_type = diesel::sql_types::BigInt)]
565 backlog: i64,
566 #[diesel(sql_type = diesel::sql_types::BigInt)]
567 in_flight: i64,
568 #[diesel(sql_type = diesel::sql_types::BigInt)]
569 scheduled: i64,
570 }
571
572 let rows: Vec<Row> = diesel::sql_query(
573 "SELECT \
574 queue_name AS queue, \
575 COUNT(*) FILTER (WHERE state = 'PENDING' AND scheduled_at <= $1) AS backlog, \
576 COUNT(*) FILTER (WHERE state = 'RUNNING') AS in_flight, \
577 COUNT(*) FILTER (WHERE state = 'PENDING' AND scheduled_at > $1) AS scheduled \
578 FROM harvest_task_queue \
579 GROUP BY queue_name",
580 )
581 .bind::<diesel::sql_types::Timestamptz, _>(Utc::now())
582 .load(conn)
583 .await
584 .map_err(crate::error::database_error)?;
585
586 Ok(rows
587 .into_iter()
588 .map(|r| QueueTaskCounts {
589 queue: r.queue,
590 backlog: r.backlog,
591 in_flight: r.in_flight,
592 scheduled: r.scheduled,
593 })
594 .collect())
595}
596
/// Usage of one `(concurrency_key, task_type)` group, as produced by
/// `concurrency_key_stats`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ConcurrencyKeyStats {
    /// The concurrency key.
    pub key: String,
    /// Task type the group applies to (`"workflow"` or `"activity"`).
    pub task_type: String,
    /// Largest `concurrency_cap` among the group's counted rows.
    pub max_concurrent: i32,
    /// Rows in state `RUNNING` that have a `worker_id`.
    pub in_flight: i64,
    /// Rows in state `PENDING`, including those scheduled for the future.
    pub pending: i64,
}
618
619pub async fn concurrency_key_stats(
628 conn: &mut AsyncPgConnection,
629 queues: &[String],
630) -> HarvestResult<Vec<ConcurrencyKeyStats>> {
631 #[derive(diesel::QueryableByName)]
632 struct Row {
633 #[diesel(sql_type = diesel::sql_types::Text)]
634 key: String,
635 #[diesel(sql_type = diesel::sql_types::Text)]
636 task_type: String,
637 #[diesel(sql_type = diesel::sql_types::Integer)]
638 max_concurrent: i32,
639 #[diesel(sql_type = diesel::sql_types::BigInt)]
640 in_flight: i64,
641 #[diesel(sql_type = diesel::sql_types::BigInt)]
642 pending: i64,
643 }
644
645 let rows: Vec<Row> = diesel::sql_query(
646 "SELECT \
647 concurrency_key AS key, \
648 task_type, \
649 MAX(concurrency_cap)::INT4 AS max_concurrent, \
650 COUNT(*) FILTER (WHERE state = 'RUNNING' AND worker_id IS NOT NULL) AS in_flight, \
651 COUNT(*) FILTER (WHERE state = 'PENDING') AS pending \
652 FROM harvest_task_queue \
653 WHERE concurrency_key IS NOT NULL \
654 AND concurrency_cap IS NOT NULL \
655 AND queue_name = ANY($1) \
656 AND (state = 'PENDING' OR (state = 'RUNNING' AND worker_id IS NOT NULL)) \
657 GROUP BY concurrency_key, task_type",
658 )
659 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(queues)
660 .load(conn)
661 .await
662 .map_err(crate::error::database_error)?;
663
664 Ok(rows
665 .into_iter()
666 .map(|r| ConcurrencyKeyStats {
667 key: r.key,
668 task_type: r.task_type,
669 max_concurrent: r.max_concurrent,
670 in_flight: r.in_flight,
671 pending: r.pending,
672 })
673 .collect())
674}
675
676pub(crate) async fn task_state_for_update(
688 conn: &mut AsyncPgConnection,
689 task_id: Uuid,
690) -> HarvestResult<Option<String>> {
691 use crate::schema::harvest_task_queue::dsl;
692
693 dsl::harvest_task_queue
694 .find(task_id)
695 .for_update()
696 .select(dsl::state)
697 .first::<String>(conn)
698 .await
699 .optional()
700 .map_err(crate::error::database_error)
701}
702
703pub async fn complete_task(
706 conn: &mut AsyncPgConnection,
707 task_id: Uuid,
708 output: serde_json::Value,
709) -> HarvestResult<()> {
710 use crate::schema::harvest_task_queue::dsl;
711
712 let updated = diesel::update(
713 dsl::harvest_task_queue
714 .find(task_id)
715 .filter(dsl::state.eq("RUNNING")),
716 )
717 .set((
718 dsl::state.eq("COMPLETED"),
719 dsl::output.eq(Some(output)),
720 dsl::heartbeat_details.eq(None::<serde_json::Value>),
721 dsl::error.eq(None::<String>),
722 dsl::completed_at.eq(Some(Utc::now())),
723 ))
724 .execute(conn)
725 .await
726 .map_err(crate::error::database_error)?;
727
728 if updated == 0 {
729 return Err(crate::error::HarvestError::NotFound(format!(
730 "task queue item {task_id} is not running"
731 )));
732 }
733
734 Ok(())
735}
736
737pub async fn fail_task(
747 conn: &mut AsyncPgConnection,
748 task_id: Uuid,
749 error: &str,
750) -> HarvestResult<()> {
751 use crate::schema::harvest_task_queue::dsl;
752
753 let updated = diesel::update(
754 dsl::harvest_task_queue
755 .find(task_id)
756 .filter(dsl::state.eq_any(["PENDING", "RUNNING"])),
757 )
758 .set((
759 dsl::state.eq("FAILED"),
760 dsl::error.eq(Some(error)),
761 dsl::heartbeat_details.eq(None::<serde_json::Value>),
762 dsl::completed_at.eq(Some(Utc::now())),
763 ))
764 .execute(conn)
765 .await
766 .map_err(crate::error::database_error)?;
767
768 if updated == 0 {
769 return Err(crate::error::HarvestError::NotFound(format!(
770 "task queue item {task_id} is not pending or running"
771 )));
772 }
773
774 Ok(())
775}
776
777pub async fn fail_open_tasks_for_execution(
788 conn: &mut AsyncPgConnection,
789 exec_id: ExecutionId,
790 error: &str,
791) -> HarvestResult<usize> {
792 use crate::schema::harvest_task_queue::dsl;
793
794 diesel::update(
795 dsl::harvest_task_queue
796 .filter(dsl::workflow_exec_id.eq(Some(exec_id.as_uuid())))
797 .filter(dsl::state.eq_any(["PENDING", "RUNNING"])),
798 )
799 .set((
800 dsl::state.eq("FAILED"),
801 dsl::error.eq(Some(error.to_string())),
802 dsl::heartbeat_details.eq(None::<serde_json::Value>),
803 dsl::completed_at.eq(Some(Utc::now())),
804 ))
805 .execute(conn)
806 .await
807 .map_err(crate::error::database_error)
808}
809
810pub async fn cancel_open_tasks_for_execution(
820 conn: &mut AsyncPgConnection,
821 exec_id: ExecutionId,
822 reason: &str,
823) -> HarvestResult<usize> {
824 use crate::schema::harvest_task_queue::dsl;
825
826 diesel::update(
827 dsl::harvest_task_queue
828 .filter(dsl::workflow_exec_id.eq(Some(exec_id.as_uuid())))
829 .filter(dsl::state.eq_any(["PENDING", "RUNNING"])),
830 )
831 .set((
832 dsl::state.eq("CANCELLED"),
833 dsl::worker_id.eq(None::<String>),
834 dsl::error.eq(Some(reason.to_string())),
835 dsl::heartbeat_details.eq(None::<serde_json::Value>),
836 dsl::completed_at.eq(Some(Utc::now())),
837 ))
838 .execute(conn)
839 .await
840 .map_err(crate::error::database_error)
841}
842
843pub async fn queue_depths(
853 conn: &mut AsyncPgConnection,
854 queues: &[String],
855) -> HarvestResult<Vec<(String, i64)>> {
856 #[derive(diesel::QueryableByName)]
857 struct Row {
858 #[diesel(sql_type = diesel::sql_types::Text)]
859 queue_name: String,
860 #[diesel(sql_type = diesel::sql_types::BigInt)]
861 depth: i64,
862 }
863
864 let rows: Vec<Row> = diesel::sql_query(
865 "SELECT queue_name, COUNT(*)::BIGINT AS depth \
866 FROM harvest_task_queue \
867 WHERE queue_name = ANY($1) \
868 AND state = 'PENDING' \
869 AND scheduled_at <= NOW() \
870 GROUP BY queue_name",
871 )
872 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(queues)
873 .load(conn)
874 .await
875 .map_err(crate::error::database_error)?;
876
877 Ok(rows.into_iter().map(|r| (r.queue_name, r.depth)).collect())
878}
879
880pub async fn record_heartbeat(
886 conn: &mut AsyncPgConnection,
887 task_id: Uuid,
888 details: serde_json::Value,
889) -> HarvestResult<()> {
890 use crate::schema::harvest_task_queue::dsl;
891
892 let updated = diesel::update(
893 dsl::harvest_task_queue
894 .find(task_id)
895 .filter(dsl::state.eq("RUNNING")),
896 )
897 .set((
898 dsl::last_heartbeat_at.eq(Some(Utc::now())),
899 dsl::heartbeat_details.eq(Some(details)),
900 ))
901 .execute(conn)
902 .await
903 .map_err(crate::error::database_error)?;
904
905 if updated == 0 {
906 return Err(crate::error::HarvestError::NotFound(format!(
907 "task queue item {task_id} is not running"
908 )));
909 }
910
911 Ok(())
912}
913
914pub async fn requeue_for_retry(
930 conn: &mut AsyncPgConnection,
931 task_id: Uuid,
932 delay: Duration,
933 previous_error: &str,
934) -> HarvestResult<()> {
935 use crate::schema::harvest_task_queue::dsl;
936
937 let next_run = Utc::now() + delay;
938
939 let queue_name = diesel::update(
940 dsl::harvest_task_queue
941 .find(task_id)
942 .filter(dsl::state.eq("RUNNING")),
943 )
944 .set((
945 dsl::state.eq("PENDING"),
946 dsl::worker_id.eq(None::<String>),
947 dsl::started_at.eq(None::<chrono::DateTime<Utc>>),
948 dsl::last_heartbeat_at.eq(None::<chrono::DateTime<Utc>>),
949 dsl::crash_strikes.eq(0),
950 dsl::scheduled_at.eq(next_run),
951 dsl::error.eq(Some(previous_error)),
952 ))
953 .returning(dsl::queue_name)
954 .get_result::<String>(conn)
955 .await
956 .optional()
957 .map_err(crate::error::database_error)?
958 .ok_or_else(|| {
959 crate::error::HarvestError::NotFound(format!("task queue item {task_id} is not running"))
960 })?;
961
962 crate::notify::notify_task_enqueued(conn, &queue_name, task_id).await?;
963
964 Ok(())
965}
966
967pub async fn reschedule_task(
977 conn: &mut AsyncPgConnection,
978 task_id: Uuid,
979 scheduled_at: chrono::DateTime<Utc>,
980) -> HarvestResult<()> {
981 use crate::schema::harvest_task_queue::dsl;
982
983 let queue_name = diesel::update(
984 dsl::harvest_task_queue
985 .find(task_id)
986 .filter(dsl::state.eq("RUNNING")),
987 )
988 .set((
989 dsl::state.eq("PENDING"),
990 dsl::worker_id.eq(None::<String>),
991 dsl::started_at.eq(None::<chrono::DateTime<Utc>>),
992 dsl::last_heartbeat_at.eq(None::<chrono::DateTime<Utc>>),
993 dsl::crash_strikes.eq(0),
998 dsl::scheduled_at.eq(scheduled_at),
999 ))
1000 .returning(dsl::queue_name)
1001 .get_result::<String>(conn)
1002 .await
1003 .optional()
1004 .map_err(crate::error::database_error)?
1005 .ok_or_else(|| {
1006 crate::error::HarvestError::NotFound(format!("task queue item {task_id} is not running"))
1007 })?;
1008
1009 crate::notify::notify_task_enqueued(conn, &queue_name, task_id).await?;
1010
1011 Ok(())
1012}
1013
1014pub async fn defer_rate_limited_task(
1028 conn: &mut AsyncPgConnection,
1029 task_id: Uuid,
1030 scheduled_at: chrono::DateTime<Utc>,
1031) -> HarvestResult<()> {
1032 use crate::schema::harvest_task_queue::dsl;
1033
1034 let queue_name = diesel::update(
1035 dsl::harvest_task_queue
1036 .find(task_id)
1037 .filter(dsl::state.eq("RUNNING")),
1038 )
1039 .set((
1040 dsl::state.eq("PENDING"),
1041 dsl::worker_id.eq(None::<String>),
1042 dsl::started_at.eq(None::<chrono::DateTime<Utc>>),
1043 dsl::last_heartbeat_at.eq(None::<chrono::DateTime<Utc>>),
1044 dsl::crash_strikes.eq(0),
1045 dsl::attempt.eq(diesel::dsl::sql::<diesel::sql_types::Integer>(
1048 "GREATEST(attempt - 1, 0)",
1049 )),
1050 dsl::scheduled_at.eq(scheduled_at),
1051 ))
1052 .returning(dsl::queue_name)
1053 .get_result::<String>(conn)
1054 .await
1055 .optional()
1056 .map_err(crate::error::database_error)?
1057 .ok_or_else(|| {
1058 crate::error::HarvestError::NotFound(format!("task queue item {task_id} is not running"))
1059 })?;
1060
1061 crate::notify::notify_task_enqueued(conn, &queue_name, task_id).await?;
1062
1063 Ok(())
1064}
1065
/// Request to pin a task to a specific worker for a limited time.
///
/// Consumed by `set_task_sticky_affinity` and `park_workflow_task`. Both write
/// `sticky_until = NOW() + timeout` using the database clock.
#[derive(Debug, Clone, Copy)]
pub struct StickyHint<'a> {
    /// Worker that should be preferred when the task is next claimed.
    pub worker_id: &'a str,
    /// Pin duration. Must fit in a `chrono::Duration`.
    pub timeout: StdDuration,
}
1076
1077impl<'a> StickyHint<'a> {
1078 #[must_use]
1080 pub const fn new(worker_id: &'a str, timeout: StdDuration) -> Self {
1081 Self { worker_id, timeout }
1082 }
1083
1084 fn chrono_timeout(self) -> HarvestResult<chrono::Duration> {
1085 chrono::Duration::from_std(self.timeout).map_err(|_| {
1086 crate::error::HarvestError::Config(
1087 "sticky_timeout exceeds chrono duration range".to_string(),
1088 )
1089 })
1090 }
1091}
1092
1093pub async fn set_task_sticky_affinity(
1107 conn: &mut AsyncPgConnection,
1108 task_id: Uuid,
1109 sticky: Option<StickyHint<'_>>,
1110) -> HarvestResult<()> {
1111 let updated = if let Some(hint) = sticky {
1115 let timeout = hint.chrono_timeout()?;
1116 diesel::sql_query(
1117 "UPDATE harvest_task_queue \
1118 SET sticky_worker_id = $2, \
1119 sticky_until = NOW() + $3, \
1120 sticky_timeout = $3 \
1121 WHERE id = $1",
1122 )
1123 .bind::<diesel::sql_types::Uuid, _>(task_id)
1124 .bind::<diesel::sql_types::Text, _>(hint.worker_id)
1125 .bind::<diesel::sql_types::Interval, _>(timeout)
1126 .execute(conn)
1127 .await
1128 .map_err(crate::error::database_error)?
1129 } else {
1130 use crate::schema::harvest_task_queue::dsl;
1131 diesel::update(dsl::harvest_task_queue.find(task_id))
1132 .set((
1133 dsl::sticky_worker_id.eq(None::<String>),
1134 dsl::sticky_until.eq(None::<chrono::DateTime<Utc>>),
1135 dsl::sticky_timeout.eq(None::<chrono::Duration>),
1136 ))
1137 .execute(conn)
1138 .await
1139 .map_err(crate::error::database_error)?
1140 };
1141
1142 if updated == 0 {
1143 return Err(crate::error::HarvestError::NotFound(format!(
1144 "task queue item {task_id} does not exist"
1145 )));
1146 }
1147
1148 Ok(())
1149}
1150
1151pub async fn park_workflow_task(
1166 conn: &mut AsyncPgConnection,
1167 task_id: Uuid,
1168 sticky: Option<StickyHint<'_>>,
1169) -> HarvestResult<()> {
1170 let updated = if let Some(hint) = sticky {
1175 let timeout = hint.chrono_timeout()?;
1176 diesel::sql_query(
1177 "UPDATE harvest_task_queue \
1178 SET worker_id = NULL, \
1179 started_at = NULL, \
1180 sticky_worker_id = $2, \
1181 sticky_until = NOW() + $3, \
1182 sticky_timeout = $3 \
1183 WHERE id = $1 \
1184 AND task_type = 'workflow' \
1185 AND state = 'RUNNING'",
1186 )
1187 .bind::<diesel::sql_types::Uuid, _>(task_id)
1188 .bind::<diesel::sql_types::Text, _>(hint.worker_id)
1189 .bind::<diesel::sql_types::Interval, _>(timeout)
1190 .execute(conn)
1191 .await
1192 .map_err(crate::error::database_error)?
1193 } else {
1194 use crate::schema::harvest_task_queue::dsl;
1200 diesel::update(
1201 dsl::harvest_task_queue
1202 .find(task_id)
1203 .filter(dsl::task_type.eq(TaskType::Workflow.as_str()))
1204 .filter(dsl::state.eq("RUNNING")),
1205 )
1206 .set((
1207 dsl::worker_id.eq(None::<String>),
1208 dsl::started_at.eq(None::<chrono::DateTime<Utc>>),
1209 dsl::sticky_worker_id.eq(None::<String>),
1210 dsl::sticky_until.eq(None::<chrono::DateTime<Utc>>),
1211 dsl::sticky_timeout.eq(None::<chrono::Duration>),
1212 ))
1213 .execute(conn)
1214 .await
1215 .map_err(crate::error::database_error)?
1216 };
1217
1218 if updated == 0 {
1219 return Err(crate::error::HarvestError::NotFound(format!(
1220 "workflow task queue item {task_id} is not running"
1221 )));
1222 }
1223
1224 Ok(())
1225}
1226
1227pub async fn wake_workflow_task(
1244 conn: &mut AsyncPgConnection,
1245 exec_id: ExecutionId,
1246) -> HarvestResult<()> {
1247 let queue_names: Vec<String> = {
1251 use diesel::deserialize::QueryableByName;
1252 use diesel::sql_types::Text;
1253
1254 #[derive(QueryableByName)]
1255 struct QueueNameRow {
1256 #[diesel(sql_type = Text)]
1257 queue_name: String,
1258 }
1259
1260 let rows: Vec<QueueNameRow> = diesel::sql_query(
1261 "UPDATE harvest_task_queue \
1262 SET state = 'PENDING', \
1263 worker_id = NULL, \
1264 started_at = NULL, \
1265 scheduled_at = $2, \
1266 activity_name = NULL, \
1267 sticky_until = CASE \
1268 WHEN sticky_worker_id IS NOT NULL AND sticky_timeout IS NOT NULL \
1269 THEN NOW() + sticky_timeout \
1270 ELSE sticky_until \
1271 END \
1272 WHERE workflow_exec_id = $1 \
1273 AND task_type = 'workflow' \
1274 AND ( \
1275 (state = 'RUNNING' AND worker_id IS NULL AND started_at IS NULL) \
1276 OR (state = 'PENDING' AND scheduled_at > $2 AND activity_name = 'mixed_signal_suspension') \
1277 ) \
1278 RETURNING queue_name",
1279 )
1280 .bind::<diesel::sql_types::Uuid, _>(exec_id.as_uuid())
1281 .bind::<diesel::sql_types::Timestamptz, _>(Utc::now() - IMMEDIATE_SCHEDULE_SKEW_ALLOWANCE)
1282 .load(conn)
1283 .await
1284 .map_err(crate::error::database_error)?;
1285
1286 rows.into_iter().map(|r| r.queue_name).collect()
1287 };
1288
1289 let already_due_queue_names: Vec<String> = {
1296 use diesel::deserialize::QueryableByName;
1297 use diesel::sql_types::Text;
1298
1299 #[derive(QueryableByName)]
1300 struct QueueNameRow {
1301 #[diesel(sql_type = Text)]
1302 queue_name: String,
1303 }
1304
1305 let rows: Vec<QueueNameRow> = diesel::sql_query(
1306 "SELECT DISTINCT queue_name FROM harvest_task_queue \
1307 WHERE workflow_exec_id = $1 \
1308 AND task_type = 'workflow' \
1309 AND state = 'PENDING' \
1310 AND scheduled_at <= $2",
1311 )
1312 .bind::<diesel::sql_types::Uuid, _>(exec_id.as_uuid())
1313 .bind::<diesel::sql_types::Timestamptz, _>(Utc::now())
1314 .load(conn)
1315 .await
1316 .map_err(crate::error::database_error)?;
1317
1318 rows.into_iter().map(|r| r.queue_name).collect()
1319 };
1320
1321 let mut queue_names = queue_names;
1322 queue_names.extend(already_due_queue_names);
1323 queue_names.sort();
1324 queue_names.dedup();
1325
1326 crate::notify::notify_tasks_enqueued(conn, &queue_names, Uuid::nil()).await?;
1327
1328 Ok(())
1329}
1330
1331pub async fn update_task_priority(
1346 conn: &mut AsyncPgConnection,
1347 task_id: Uuid,
1348 priority: Priority,
1349) -> HarvestResult<bool> {
1350 use crate::schema::harvest_task_queue::dsl;
1351
1352 let updated = diesel::update(
1353 dsl::harvest_task_queue
1354 .find(task_id)
1355 .filter(dsl::state.eq_any(["PENDING", "RUNNING"])),
1356 )
1357 .set(dsl::priority.eq(priority.as_i32()))
1358 .execute(conn)
1359 .await
1360 .map_err(crate::error::database_error)?;
1361
1362 Ok(updated > 0)
1363}
1364
1365pub async fn task_exists(conn: &mut AsyncPgConnection, task_id: Uuid) -> HarvestResult<bool> {
1371 use crate::schema::harvest_task_queue::dsl;
1372
1373 let found: Option<Uuid> = dsl::harvest_task_queue
1374 .filter(dsl::id.eq(task_id))
1375 .select(dsl::id)
1376 .first::<Uuid>(conn)
1377 .await
1378 .optional()
1379 .map_err(crate::error::database_error)?;
1380
1381 Ok(found.is_some())
1382}
1383
1384pub async fn check_throttled_keys(
1392 conn: &mut AsyncPgConnection,
1393 queues: &[String],
1394) -> HarvestResult<Vec<String>> {
1395 #[derive(diesel::QueryableByName)]
1396 struct Row {
1397 #[diesel(sql_type = diesel::sql_types::Text)]
1398 rate_limit_key: String,
1399 }
1400
1401 let rows: Vec<Row> = diesel::sql_query(
1402 "SELECT DISTINCT q.rate_limit_key \
1403 FROM harvest_task_queue q \
1404 JOIN harvest_rate_limit_buckets b ON b.key = q.rate_limit_key \
1405 WHERE q.queue_name = ANY($1) \
1406 AND q.state = 'PENDING' \
1407 AND q.scheduled_at <= NOW() \
1408 AND LEAST(b.burst, b.tokens + EXTRACT(EPOCH FROM (NOW() - b.last_refilled_at)) * b.refill_rate) < 1.0"
1409 )
1410 .bind::<diesel::sql_types::Array<diesel::sql_types::Text>, _>(queues)
1411 .load(conn)
1412 .await
1413 .map_err(crate::error::database_error)?;
1414
1415 Ok(rows.into_iter().map(|r| r.rate_limit_key).collect())
1416}
1417
1418pub async fn try_consume_rate_limit_token(
1446 conn: &mut AsyncPgConnection,
1447 key: &str,
1448) -> HarvestResult<bool> {
1449 #[derive(diesel::QueryableByName)]
1450 struct Outcome {
1451 #[diesel(sql_type = diesel::sql_types::Bool)]
1452 debited: bool,
1453 }
1454
1455 let outcome: Option<Outcome> = diesel::sql_query(
1459 "WITH debited AS ( \
1460 UPDATE harvest_rate_limit_buckets \
1461 SET tokens = LEAST(burst, tokens + EXTRACT(EPOCH FROM (NOW() - last_refilled_at)) * refill_rate) - 1.0, \
1462 last_refilled_at = NOW() \
1463 WHERE key = $1 \
1464 AND LEAST(burst, tokens + EXTRACT(EPOCH FROM (NOW() - last_refilled_at)) * refill_rate) >= 1.0 \
1465 RETURNING key \
1466 ) \
1467 SELECT EXISTS (SELECT 1 FROM debited) AS debited",
1468 )
1469 .bind::<diesel::sql_types::Text, _>(key)
1470 .get_result(conn)
1471 .await
1472 .optional()
1473 .map_err(crate::error::database_error)?;
1474
1475 Ok(outcome.is_some_and(|o| o.debited))
1477}
1478
1479pub async fn refund_rate_limit_token(conn: &mut AsyncPgConnection, key: &str) -> HarvestResult<()> {
1495 diesel::sql_query(
1496 "UPDATE harvest_rate_limit_buckets \
1497 SET tokens = LEAST(burst, tokens + EXTRACT(EPOCH FROM (NOW() - last_refilled_at)) * refill_rate + 1.0), \
1498 last_refilled_at = NOW() \
1499 WHERE key = $1",
1500 )
1501 .bind::<diesel::sql_types::Text, _>(key)
1502 .execute(conn)
1503 .await
1504 .map_err(crate::error::database_error)?;
1505 Ok(())
1506}
1507
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `EnqueueParams` whose input payload is JSON `null`, the baseline most tests use.
    fn null_input(queue: &'static str, task_type: TaskType) -> EnqueueParams {
        EnqueueParams::new(queue, task_type, serde_json::json!(null))
    }

    #[test]
    fn enqueue_params_builds_correctly() {
        let expected_input = serde_json::json!({"to": "alice"});
        let p = EnqueueParams::new("email-queue", TaskType::Activity, expected_input.clone());

        // Constructor arguments are stored as given.
        assert_eq!(p.queue_name, "email-queue");
        assert_eq!(p.task_type, TaskType::Activity);
        assert_eq!(p.input, expected_input);

        // Numeric settings take their default values.
        assert_eq!(p.priority, 0);
        assert_eq!(p.max_attempts, 3);

        // Optional attributes start out unset.
        assert!(p.workflow_exec_id.is_none());
        assert!(p.activity_name.is_none());
        assert!(p.heartbeat_timeout.is_none());
        assert!(p.start_to_close.is_none());
        assert!(p.schedule_to_start.is_none());
        assert!(p.retry_policy.is_none());
        assert!(p.trace_context.is_none());
    }

    #[test]
    fn enqueue_params_with_trace_context_attaches_carrier() {
        let carrier = TraceContextCarrier::from_traceparent("00-abcd-ef01-01");
        let p = null_input("billing", TaskType::Workflow).with_trace_context(carrier.clone());

        assert_eq!(p.trace_context, Some(carrier));
    }

    #[test]
    fn task_type_display() {
        let cases = [
            (TaskType::Workflow, "workflow"),
            (TaskType::Activity, "activity"),
        ];
        for (kind, label) in cases {
            // `as_str` and the `Display` impl must agree on the label.
            assert_eq!(kind.as_str(), label);
            assert_eq!(kind.to_string(), label);
        }
    }

    #[test]
    fn enqueue_params_with_overrides() {
        let p = EnqueueParams {
            priority: 10,
            max_attempts: 5,
            workflow_exec_id: Some(Uuid::new_v4()),
            ..null_input("billing", TaskType::Workflow)
        };

        assert_eq!(p.priority, 10);
        assert_eq!(p.max_attempts, 5);
        assert!(p.workflow_exec_id.is_some());
    }

    #[test]
    fn enqueue_params_defaults_have_no_sticky_pin() {
        let p = null_input("default", TaskType::Workflow);

        assert_eq!(p.sticky_worker_id, None);
        assert_eq!(p.sticky_timeout, None);
    }

    #[test]
    fn enqueue_params_with_sticky_sets_both_fields() {
        let pin = StdDuration::from_secs(7);
        let p = null_input("default", TaskType::Workflow).with_sticky("worker-42", pin);

        assert_eq!(p.sticky_worker_id.as_deref(), Some("worker-42"));
        assert_eq!(p.sticky_timeout, Some(pin));
    }

    #[test]
    fn sticky_hint_constructs_with_fields() {
        let wait = StdDuration::from_secs(3);
        let hint = StickyHint::new("w1", wait);

        assert_eq!(hint.worker_id, "w1");
        assert_eq!(hint.timeout, wait);
    }

    #[test]
    fn sticky_hint_rejects_out_of_range_duration() {
        // u64::MAX seconds is far beyond what a chrono duration can represent.
        let oversized = StickyHint::new("w1", StdDuration::from_secs(u64::MAX));

        assert!(oversized.chrono_timeout().is_err());
    }

    #[test]
    fn enqueue_params_concurrency_fields_default_to_none() {
        let p = null_input("default", TaskType::Activity);

        assert_eq!(p.concurrency_key, None);
        assert_eq!(p.max_concurrent, None);
    }

    #[test]
    fn enqueue_params_schedule_to_close_at_defaults_to_none() {
        let p = null_input("default", TaskType::Activity);

        assert!(
            p.schedule_to_close_at.is_none(),
            "schedule_to_close_at must default to None (unbounded)"
        );
    }

    #[test]
    fn enqueue_params_schedule_to_close_at_can_be_set() {
        let deadline = Utc::now() + Duration::seconds(300);
        let p = EnqueueParams {
            schedule_to_close_at: Some(deadline),
            ..null_input("default", TaskType::Activity)
        };

        assert_eq!(p.schedule_to_close_at, Some(deadline));
    }

    #[test]
    fn enqueue_params_concurrency_fields_set_manually() {
        let p = EnqueueParams {
            concurrency_key: Some("stripe".to_string()),
            max_concurrent: Some(5),
            ..null_input("default", TaskType::Activity)
        };

        assert_eq!(p.concurrency_key.as_deref(), Some("stripe"));
        assert_eq!(p.max_concurrent, Some(5));
    }
}