1use crate::{
17 COUNTER_INCREMENT_HEADER, Entry, Result,
18 errors::{CounterError, CounterErrorKind},
19 parser::{
20 parse_counter_value, parse_counter_value_from_string, parse_increment, parse_sources,
21 },
22};
23use async_nats::{
24 HeaderMap,
25 jetstream::{self, stream::Stream},
26 subject::ToSubject,
27};
28use futures_util::{Stream as FutureStream, TryStreamExt};
29use jetstream_extra::batch_fetch::BatchFetchExt;
30use num_bigint::BigInt;
31use std::pin::Pin;
32
33pub struct Counter {
35 stream: Stream,
36 context: jetstream::Context,
37}
38
39impl Counter {
40 pub async fn from_stream(context: jetstream::Context, mut stream: Stream) -> Result<Self> {
42 let info = stream
43 .info()
44 .await
45 .map_err(|e| CounterError::with_source(CounterErrorKind::Request, e))?;
46
47 if !info.config.allow_message_counter {
48 return Err(CounterError::new(CounterErrorKind::CounterNotEnabled));
49 }
50
51 if !info.config.allow_direct {
52 return Err(CounterError::new(CounterErrorKind::DirectAccessRequired));
53 }
54
55 Ok(Self { stream, context })
56 }
57
58 pub async fn add<S, V>(&self, subject: S, value: V) -> Result<BigInt>
60 where
61 S: ToSubject,
62 V: Into<BigInt> + Send,
63 {
64 let value = value.into();
65
66 let mut headers = HeaderMap::new();
67 headers.insert(COUNTER_INCREMENT_HEADER, value.to_string());
68
69 let ack = self
70 .context
71 .publish_with_headers(subject, headers, vec![].into())
72 .await
73 .map_err(|e| CounterError::with_source(CounterErrorKind::Publish, e))?
74 .await
75 .map_err(|e| CounterError::with_source(CounterErrorKind::Publish, e))?;
76
77 parse_counter_value_from_string(ack.value)
78 }
79
80 pub async fn get<S: ToSubject>(&self, subject: S) -> Result<Entry> {
85 let subject_str = subject.to_subject();
86
87 let msg = self
88 .stream
89 .get_last_raw_message_by_subject(&subject_str)
90 .await
91 .map_err(|e| match e.kind() {
92 jetstream::stream::LastRawMessageErrorKind::NoMessageFound => {
93 CounterError::new(CounterErrorKind::NoCounterForSubject)
94 }
95 _ => CounterError::with_source(CounterErrorKind::Stream, e),
96 })?;
97
98 let value = parse_counter_value(&msg.payload)?;
99 let sources = parse_sources(&msg.headers)?;
100 let increment = parse_increment(&msg.headers)?;
101
102 Ok(Entry {
103 subject: subject_str.to_string(),
104 value,
105 sources,
106 increment,
107 })
108 }
109
110 pub async fn load<S: ToSubject>(&self, subject: S) -> Result<BigInt> {
115 Ok(self.get(subject).await?.value)
116 }
117
118 pub async fn increment<S>(&self, subject: S) -> Result<BigInt>
121 where
122 S: ToSubject,
123 {
124 self.add(subject, 1).await
125 }
126
127 pub async fn decrement<S>(&self, subject: S) -> Result<BigInt>
130 where
131 S: ToSubject,
132 {
133 self.add(subject, -1).await
134 }
135
136 pub async fn get_multiple(
144 &self,
145 subjects: Vec<String>,
146 ) -> Result<Pin<Box<dyn FutureStream<Item = Result<Entry>> + Send + '_>>> {
147 if subjects.is_empty() {
149 return Err(CounterError::new(CounterErrorKind::InvalidCounterValue));
150 }
151
152 for subject in &subjects {
153 if subject.is_empty() {
154 return Err(CounterError::new(CounterErrorKind::InvalidCounterValue));
155 }
156 }
157
158 let stream_name = self.stream.cached_info().config.name.clone();
159
160 let messages = self
161 .context
162 .get_last_messages_for(&stream_name)
163 .subjects(subjects)
164 .send()
165 .await
166 .map_err(|e| CounterError::with_source(CounterErrorKind::Stream, e))?;
167
168 let entries = messages
169 .map_err(|e| CounterError::with_source(CounterErrorKind::Stream, e))
170 .and_then(|msg| async move {
171 let value = parse_counter_value(&msg.payload)?;
172 let sources = parse_sources(&msg.headers)?;
173 let increment = parse_increment(&msg.headers)?;
174
175 Ok(Entry {
176 subject: msg.subject.to_string(),
177 value,
178 sources,
179 increment,
180 })
181 });
182
183 Ok(Box::pin(entries)
184 as Pin<
185 Box<dyn FutureStream<Item = Result<Entry>> + Send + '_>,
186 >)
187 }
188}