nats_counters/
counter.rs

1// Copyright 2025 Synadia Communications Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! JetStream counter implementation.
15
16use crate::{
17    COUNTER_INCREMENT_HEADER, Entry, Result,
18    errors::{CounterError, CounterErrorKind},
19    parser::{
20        parse_counter_value, parse_counter_value_from_string, parse_increment, parse_sources,
21    },
22};
23use async_nats::{
24    HeaderMap,
25    jetstream::{self, stream::Stream},
26    subject::ToSubject,
27};
28use futures_util::{Stream as FutureStream, TryStreamExt};
29use jetstream_extra::batch_fetch::BatchFetchExt;
30use num_bigint::BigInt;
31use std::pin::Pin;
32
33/// Implementation of distributed counters using a JetStream stream.
34pub struct Counter {
35    stream: Stream,
36    context: jetstream::Context,
37}
38
39impl Counter {
40    /// Creates a counter from an existing stream.
41    pub async fn from_stream(context: jetstream::Context, mut stream: Stream) -> Result<Self> {
42        let info = stream
43            .info()
44            .await
45            .map_err(|e| CounterError::with_source(CounterErrorKind::Request, e))?;
46
47        if !info.config.allow_message_counter {
48            return Err(CounterError::new(CounterErrorKind::CounterNotEnabled));
49        }
50
51        if !info.config.allow_direct {
52            return Err(CounterError::new(CounterErrorKind::DirectAccessRequired));
53        }
54
55        Ok(Self { stream, context })
56    }
57
58    /// Adds a value to the counter for the given subject.
59    pub async fn add<S, V>(&self, subject: S, value: V) -> Result<BigInt>
60    where
61        S: ToSubject,
62        V: Into<BigInt> + Send,
63    {
64        let value = value.into();
65
66        let mut headers = HeaderMap::new();
67        headers.insert(COUNTER_INCREMENT_HEADER, value.to_string());
68
69        let ack = self
70            .context
71            .publish_with_headers(subject, headers, vec![].into())
72            .await
73            .map_err(|e| CounterError::with_source(CounterErrorKind::Publish, e))?
74            .await
75            .map_err(|e| CounterError::with_source(CounterErrorKind::Publish, e))?;
76
77        parse_counter_value_from_string(ack.value)
78    }
79
80    /// Gets the counter entry for a subject, including sources and last increment.
81    ///
82    /// Returns complete information about a counter including its current value,
83    /// source breakdown (if using stream sourcing), and the last increment value.
84    pub async fn get<S: ToSubject>(&self, subject: S) -> Result<Entry> {
85        let subject_str = subject.to_subject();
86
87        let msg = self
88            .stream
89            .get_last_raw_message_by_subject(&subject_str)
90            .await
91            .map_err(|e| match e.kind() {
92                jetstream::stream::LastRawMessageErrorKind::NoMessageFound => {
93                    CounterError::new(CounterErrorKind::NoCounterForSubject)
94                }
95                _ => CounterError::with_source(CounterErrorKind::Stream, e),
96            })?;
97
98        let value = parse_counter_value(&msg.payload)?;
99        let sources = parse_sources(&msg.headers)?;
100        let increment = parse_increment(&msg.headers)?;
101
102        Ok(Entry {
103            subject: subject_str.to_string(),
104            value,
105            sources,
106            increment,
107        })
108    }
109
110    /// Loads just the counter value for a subject.
111    ///
112    /// This is a convenience method that returns only the value,
113    /// without sources or increment information.
114    pub async fn load<S: ToSubject>(&self, subject: S) -> Result<BigInt> {
115        Ok(self.get(subject).await?.value)
116    }
117
118    /// Increments the counter by 1 for the given subject.
119    /// This is a convenience method for the common case of incrementing by 1.
120    pub async fn increment<S>(&self, subject: S) -> Result<BigInt>
121    where
122        S: ToSubject,
123    {
124        self.add(subject, 1).await
125    }
126
127    /// Decrements the counter by 1 for the given subject.
128    /// This is a convenience method for the common case of decrementing by 1.
129    pub async fn decrement<S>(&self, subject: S) -> Result<BigInt>
130    where
131        S: ToSubject,
132    {
133        self.add(subject, -1).await
134    }
135
136    /// Gets multiple counter entries matching the given subjects.
137    ///
138    /// Efficiently fetches multiple counters in a batch operation.
139    /// Returns a stream of entries.
140    ///
141    /// Note: Subjects should be exact subjects, not patterns. Non-existent
142    /// counters are silently skipped in the results.
143    pub async fn get_multiple(
144        &self,
145        subjects: Vec<String>,
146    ) -> Result<Pin<Box<dyn FutureStream<Item = Result<Entry>> + Send + '_>>> {
147        // Validate subjects
148        if subjects.is_empty() {
149            return Err(CounterError::new(CounterErrorKind::InvalidCounterValue));
150        }
151
152        for subject in &subjects {
153            if subject.is_empty() {
154                return Err(CounterError::new(CounterErrorKind::InvalidCounterValue));
155            }
156        }
157
158        let stream_name = self.stream.cached_info().config.name.clone();
159
160        let messages = self
161            .context
162            .get_last_messages_for(&stream_name)
163            .subjects(subjects)
164            .send()
165            .await
166            .map_err(|e| CounterError::with_source(CounterErrorKind::Stream, e))?;
167
168        let entries = messages
169            .map_err(|e| CounterError::with_source(CounterErrorKind::Stream, e))
170            .and_then(|msg| async move {
171                let value = parse_counter_value(&msg.payload)?;
172                let sources = parse_sources(&msg.headers)?;
173                let increment = parse_increment(&msg.headers)?;
174
175                Ok(Entry {
176                    subject: msg.subject.to_string(),
177                    value,
178                    sources,
179                    increment,
180                })
181            });
182
183        Ok(Box::pin(entries)
184            as Pin<
185                Box<dyn FutureStream<Item = Result<Entry>> + Send + '_>,
186            >)
187    }
188}