1use std::num::NonZeroUsize;
2
3use async_trait::async_trait;
4use ave_actors::{
5 Actor, ActorContext, ActorError, ActorPath, Event, Handler, Message,
6 NotPersistentActor, Response,
7};
8use ave_common::{
9 identity::DigestIdentifier,
10 response::{RequestInfo, RequestInfoExtend, RequestState},
11};
12use borsh::{BorshDeserialize, BorshSerialize};
13use lru::LruCache;
14use serde::{Deserialize, Serialize};
15use tracing::{Span, debug, error, info_span, warn};
16
17#[derive(Clone, Debug)]
18pub struct RequestTracking {
19 cache: LruCache<DigestIdentifier, RequestInfo>,
20}
21
22impl RequestTracking {
23 pub fn new(size: usize) -> Self {
24 let size = if size == 0 { 100 } else { size };
25
26 Self {
27 cache: LruCache::new(NonZeroUsize::new(size).expect("size > 0")),
28 }
29 }
30}
31
32impl NotPersistentActor for RequestTracking {}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
35pub enum RequestTrackingMessage {
36 UpdateState {
37 request_id: DigestIdentifier,
38 state: RequestState,
39 },
40 UpdateVersion {
41 request_id: DigestIdentifier,
42 version: u64,
43 },
44 AllRequests,
45 SearchRequest(DigestIdentifier),
46}
47
48#[derive(Debug, Clone)]
49pub enum RequestTrackingResponse {
50 Ok,
51 AllInfo(Vec<RequestInfoExtend>),
52 Info(RequestInfo),
53 NotFound,
54}
55
56impl Response for RequestTrackingResponse {}
57
58impl Message for RequestTrackingMessage {}
59
60#[async_trait]
61impl Actor for RequestTracking {
62 type Message = RequestTrackingMessage;
63 type Event = RequestTrackingEvent;
64 type Response = RequestTrackingResponse;
65
66 fn get_span(_id: &str, parent_span: Option<Span>) -> tracing::Span {
67 parent_span.map_or_else(
68 || info_span!("RequestTracking"),
69 |parent_span| info_span!(parent: parent_span, "RequestTracking"),
70 )
71 }
72}
73
74#[derive(
75 Debug, Clone, Serialize, Deserialize, BorshDeserialize, BorshSerialize,
76)]
77pub struct RequestTrackingEvent {
78 pub request_id: String,
79 pub subject_id: String,
80 pub sn: Option<u64>,
81 pub error: String,
82 pub who: String,
83 pub abort_type: String,
84}
85
86impl Event for RequestTrackingEvent {}
87
88#[async_trait]
89impl Handler<Self> for RequestTracking {
90 async fn handle_message(
91 &mut self,
92 _sender: ActorPath,
93 msg: RequestTrackingMessage,
94 ctx: &mut ave_actors::ActorContext<Self>,
95 ) -> Result<RequestTrackingResponse, ActorError> {
96 match msg {
97 RequestTrackingMessage::AllRequests => {
98 let count = self.cache.len();
99 debug!(
100 msg_type = "AllRequests",
101 requests_count = count,
102 "Retrieving all tracked requests"
103 );
104 Ok(RequestTrackingResponse::AllInfo(
105 self.cache
106 .iter()
107 .map(|x| RequestInfoExtend {
108 request_id: x.0.to_string(),
109 state: x.1.state.clone(),
110 version: x.1.version,
111 })
112 .collect(),
113 ))
114 }
115 RequestTrackingMessage::UpdateState { request_id, state } => {
116 if let Some(info) = self.cache.get_mut(&request_id) {
117 let old_state = info.state.clone();
118 info.state = state.clone();
119 debug!(
120 msg_type = "UpdateState",
121 request_id = %request_id,
122 old_state = ?old_state,
123 new_state = ?state,
124 "Request state updated"
125 );
126 } else {
127 self.cache.put(
128 request_id.clone(),
129 RequestInfo {
130 state: state.clone(),
131 version: 0,
132 },
133 );
134
135 debug!(
136 msg_type = "UpdateState",
137 request_id = %request_id,
138 state = ?state,
139 "New request tracked"
140 );
141 };
142
143 let event = match state {
144 RequestState::Invalid {
145 subject_id,
146 who,
147 sn,
148 error,
149 } => Some(RequestTrackingEvent {
150 request_id: request_id.to_string(),
151 abort_type: "Invalid".to_string(),
152 error,
153 sn,
154 subject_id,
155 who,
156 }),
157 RequestState::Abort {
158 subject_id,
159 who,
160 sn,
161 error,
162 } => Some(RequestTrackingEvent {
163 request_id: request_id.to_string(),
164 abort_type: "Abort".to_string(),
165 error,
166 sn,
167 subject_id,
168 who,
169 }),
170 _ => None,
171 };
172
173 if let Some(event) = event {
174 self.on_event(event, ctx).await;
175 }
176
177 Ok(RequestTrackingResponse::Ok)
178 }
179 RequestTrackingMessage::UpdateVersion {
180 request_id,
181 version,
182 } => {
183 if let Some(info) = self.cache.get_mut(&request_id) {
184 let old_version = info.version;
185 info.version = version;
186 debug!(
187 msg_type = "UpdateVersion",
188 request_id = %request_id,
189 old_version = old_version,
190 new_version = version,
191 "Request version updated"
192 );
193 } else {
194 warn!(
195 msg_type = "UpdateVersion",
196 request_id = %request_id,
197 version = version,
198 "Request not found in cache"
199 );
200 };
201
202 Ok(RequestTrackingResponse::Ok)
203 }
204 RequestTrackingMessage::SearchRequest(request_id) => {
205 self.cache.get(&request_id).map_or_else(
206 || {
207 debug!(
208 msg_type = "SearchRequest",
209 request_id = %request_id,
210 "Request not found in cache"
211 );
212 Ok(RequestTrackingResponse::NotFound)
213 },
214 |info| {
215 debug!(
216 msg_type = "SearchRequest",
217 request_id = %request_id,
218 state = ?info.state,
219 version = info.version,
220 "Request found in cache"
221 );
222 Ok(RequestTrackingResponse::Info(info.clone()))
223 },
224 )
225 }
226 }
227 }
228
229 async fn on_event(
230 &mut self,
231 event: RequestTrackingEvent,
232 ctx: &mut ActorContext<Self>,
233 ) {
234 if let Err(e) = ctx.publish_event(event).await {
235 error!(
236 error = %e,
237 "Failed to publish event"
238 );
239 ctx.system().crash_system();
240 };
241 }
242}