Skip to main content

throttle_net/
multi.rs

1//! Multi-dimensional limits: one request, several independent budgets.
2
3use std::sync::Arc;
4
5use crate::decision::Decision;
6#[cfg(feature = "runtime")]
7use crate::error::ThrottleError;
8use crate::limiter::{Limiter, acquire_all, peek_all};
9
10/// One named, independently-metered dimension: its label and its limiter.
11type Dimension = (Box<str>, Arc<dyn Limiter>);
12
13/// A limiter with several named dimensions, each metered independently.
14///
15/// One outbound call often spends against more than one budget at once. An LLM
16/// request, for instance, counts as one *request*, some number of *input
17/// tokens*, and some number of *output tokens* — each with its own ceiling. A
18/// `MultiLimiter` holds one limiter per dimension and admits a call only when
19/// **every** dimension can afford its share, applying the per-dimension costs
20/// atomically (peek-all-then-commit, like [`Hybrid`](crate::Hybrid)).
21///
22/// Costs are supplied per call as `(dimension, cost)` pairs. A dimension not
23/// named in a call is charged nothing; a name with no matching dimension is
24/// ignored.
25///
26/// Build one with [`MultiLimiter::builder`].
27///
28/// # Examples
29///
30/// ```
31/// # async fn run() -> Result<(), throttle_net::ThrottleError> {
32/// use std::time::Duration;
33/// use throttle_net::{MultiLimiter, Throttle};
34///
35/// let minute = Duration::from_secs(60);
36/// let limiter = MultiLimiter::builder()
37///     .dimension("requests", Throttle::per_duration(60, minute))
38///     .dimension("input_tokens", Throttle::per_duration(100_000, minute))
39///     .dimension("output_tokens", Throttle::per_duration(20_000, minute))
40///     .build();
41///
42/// // A call billed at 1 request, 1500 input tokens, 200 output tokens.
43/// limiter
44///     .acquire_costs(&[
45///         ("requests", 1),
46///         ("input_tokens", 1500),
47///         ("output_tokens", 200),
48///     ])
49///     .await?;
50/// # Ok(())
51/// # }
52/// ```
53#[derive(Clone)]
54pub struct MultiLimiter {
55    dimensions: Arc<[Dimension]>,
56}
57
/// Resolves the total cost charged to the dimension `name` for one call.
///
/// Every `(dimension, cost)` entry whose name equals `name` contributes. A call
/// that lists the same dimension more than once (for instance once for input
/// usage and once for output usage) is therefore charged the sum, not just the
/// first entry. The sum saturates at `u32::MAX` instead of wrapping, so an
/// overflowing total is clamped high rather than silently becoming a small
/// charge. A dimension that is not mentioned costs zero.
#[inline]
fn cost_for(name: &str, costs: &[(&str, u32)]) -> u32 {
    costs
        .iter()
        .filter(|(n, _)| *n == name)
        .fold(0_u32, |total, &(_, c)| total.saturating_add(c))
}
68
69impl MultiLimiter {
70    /// Starts building a multi-dimensional limiter.
71    #[must_use]
72    pub fn builder() -> MultiLimiterBuilder {
73        MultiLimiterBuilder {
74            dimensions: Vec::new(),
75        }
76    }
77
78    #[inline]
79    fn pairs<'a>(
80        &'a self,
81        costs: &'a [(&'a str, u32)],
82    ) -> impl Iterator<Item = (&'a dyn Limiter, u32)> + Clone {
83        self.dimensions
84            .iter()
85            .map(move |(name, limiter)| (limiter.as_ref(), cost_for(name, costs)))
86    }
87
88    /// Reports whether the call's per-dimension costs would all be granted now,
89    /// without taking anything.
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use throttle_net::{MultiLimiter, Throttle};
95    ///
96    /// let limiter = MultiLimiter::builder()
97    ///     .dimension("requests", Throttle::per_second(10))
98    ///     .dimension("tokens", Throttle::per_second(1000))
99    ///     .build();
100    ///
101    /// assert!(limiter.peek_costs(&[("requests", 1), ("tokens", 500)]).is_acquired());
102    /// ```
103    #[inline]
104    #[must_use]
105    pub fn peek_costs(&self, costs: &[(&str, u32)]) -> Decision {
106        peek_all(self.pairs(costs))
107    }
108
109    /// Attempts to charge the call's per-dimension costs without waiting,
110    /// returning whether every dimension granted.
111    ///
112    /// All-or-nothing across dimensions.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use throttle_net::{MultiLimiter, Throttle};
118    ///
119    /// let limiter = MultiLimiter::builder()
120    ///     .dimension("requests", Throttle::per_second(2))
121    ///     .build();
122    ///
123    /// assert!(limiter.try_acquire_costs(&[("requests", 2)]));
124    /// assert!(!limiter.try_acquire_costs(&[("requests", 1)]));
125    /// ```
126    #[inline]
127    #[must_use]
128    pub fn try_acquire_costs(&self, costs: &[(&str, u32)]) -> bool {
129        acquire_all(self.pairs(costs)).is_acquired()
130    }
131
132    /// Returns the tokens available in `dimension` right now, or `None` if there
133    /// is no such dimension.
134    #[must_use]
135    pub fn available(&self, dimension: &str) -> Option<u32> {
136        self.dimensions
137            .iter()
138            .find(|(name, _)| name.as_ref() == dimension)
139            .map(|(_, limiter)| limiter.available())
140    }
141}
142
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl MultiLimiter {
    /// Charges every dimension its share of the call, pacing the caller until
    /// all of them can pay together.
    ///
    /// This is the central multi-dimensional operation. Each attempt goes
    /// through the all-or-nothing path. While a dimension asks the caller to
    /// come back later, the task sleeps for the suggested interval and tries
    /// again. Once every budget allows the call, all of them are committed
    /// together.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] when the call is judged
    /// impossible, for instance because some dimension's cost exceeds that
    /// dimension's capacity. The error carries that dimension's cost and
    /// capacity. Such a call can never succeed.
    ///
    /// # Examples
    ///
    /// ```
    /// # async fn run() -> Result<(), throttle_net::ThrottleError> {
    /// use throttle_net::{MultiLimiter, Throttle};
    ///
    /// let limiter = MultiLimiter::builder()
    ///     .dimension("requests", Throttle::per_second(100))
    ///     .dimension("tokens", Throttle::per_second(100_000))
    ///     .build();
    ///
    /// limiter.acquire_costs(&[("requests", 1), ("tokens", 1500)]).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn acquire_costs(&self, costs: &[(&str, u32)]) -> Result<(), ThrottleError> {
        loop {
            let wait = match acquire_all(self.pairs(costs)) {
                Decision::Retry { after } => after,
                Decision::Acquired => return Ok(()),
                Decision::Impossible => return Err(self.capacity_error(costs)),
            };
            crate::rt::sleep(wait).await;
        }
    }

    /// Explains an attempt that came back [`Decision::Impossible`] as a
    /// [`ThrottleError::CostExceedsCapacity`] naming the responsible dimension's
    /// figures.
    ///
    /// The first dimension whose cost exceeds its capacity is blamed. If there
    /// is none (for example, a static, never-refilling dimension was drained
    /// below the cost), the first dimension whose own peek reports `Impossible`
    /// is blamed instead. If neither search finds a culprit, a zeroed error
    /// (`cost: 0`, `capacity: 0`) is returned.
    fn capacity_error(&self, costs: &[(&str, u32)]) -> ThrottleError {
        self.dimensions
            .iter()
            .find_map(|(name, limiter)| {
                let cost = cost_for(name, costs);
                let capacity = limiter.capacity();
                (cost > capacity).then_some(ThrottleError::CostExceedsCapacity { cost, capacity })
            })
            .or_else(|| {
                self.dimensions.iter().find_map(|(name, limiter)| {
                    let cost = cost_for(name, costs);
                    (limiter.peek(cost) == Decision::Impossible).then(|| {
                        ThrottleError::CostExceedsCapacity {
                            cost,
                            capacity: limiter.capacity(),
                        }
                    })
                })
            })
            .unwrap_or(ThrottleError::CostExceedsCapacity {
                cost: 0,
                capacity: 0,
            })
    }
}
212
213/// Builder for a [`MultiLimiter`].
214///
215/// Name each dimension with [`dimension`](Self::dimension), then
216/// [`build`](Self::build).
217///
218/// # Examples
219///
220/// ```
221/// use throttle_net::{MultiLimiter, Throttle};
222///
223/// let limiter = MultiLimiter::builder()
224///     .dimension("requests", Throttle::per_second(10))
225///     .build();
226/// # let _ = limiter;
227/// ```
228#[derive(Default)]
229pub struct MultiLimiterBuilder {
230    dimensions: Vec<Dimension>,
231}
232
233impl MultiLimiterBuilder {
234    /// Adds a named dimension backed by `limiter`.
235    ///
236    /// Adding the same name twice keeps both; each is charged independently, so
237    /// prefer distinct names.
238    #[must_use]
239    pub fn dimension(mut self, name: impl Into<Box<str>>, limiter: impl Limiter + 'static) -> Self {
240        self.dimensions.push((name.into(), Arc::new(limiter)));
241        self
242    }
243
244    /// Adds a named dimension backed by an already-shared limiter.
245    #[must_use]
246    pub fn shared(mut self, name: impl Into<Box<str>>, limiter: Arc<dyn Limiter>) -> Self {
247        self.dimensions.push((name.into(), limiter));
248        self
249    }
250
251    /// Builds the [`MultiLimiter`].
252    #[must_use]
253    pub fn build(self) -> MultiLimiter {
254        MultiLimiter {
255            dimensions: self.dimensions.into(),
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::MultiLimiter;
    use crate::throttle::Throttle;
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn require_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_multi_limiter_is_send_sync() {
        require_send_sync::<MultiLimiter>();
    }

    #[test]
    fn test_all_dimensions_must_afford_their_share() {
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(10))
            .dimension("tokens", Throttle::per_second(1000))
            .build();

        // The first call drains the whole token budget while using only one of
        // the ten requests.
        let first = multi.try_acquire_costs(&[("requests", 1), ("tokens", 1000)]);
        assert!(first);

        // Requests still have room, but even one more token is unaffordable, so
        // the whole call is refused...
        let second = multi.try_acquire_costs(&[("requests", 1), ("tokens", 1)]);
        assert!(!second);

        // ...and the refusal left the request budget untouched.
        assert_eq!(multi.available("requests"), Some(9));
    }

    #[test]
    fn test_unmentioned_dimension_is_not_charged() {
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(2))
            .dimension("tokens", Throttle::per_second(100))
            .build();

        // Only the request dimension appears in the call.
        let granted = multi.try_acquire_costs(&[("requests", 1)]);
        assert!(granted);
        assert_eq!(multi.available("tokens"), Some(100));
        assert_eq!(multi.available("requests"), Some(1));
    }

    #[test]
    fn test_unknown_dimension_name_is_ignored() {
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(1))
            .build();
        let granted = multi.try_acquire_costs(&[("requests", 1), ("nonexistent", 999)]);
        assert!(granted);
    }

    #[test]
    fn test_available_is_none_for_unknown_dimension() {
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(1))
            .build();
        let missing = multi.available("missing");
        assert_eq!(missing, None);
    }

    #[test]
    fn test_peek_costs_does_not_consume() {
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(5))
            .build();
        let verdict = multi.peek_costs(&[("requests", 5)]);
        assert!(verdict.is_acquired());
        assert_eq!(multi.available("requests"), Some(5));
    }

    #[test]
    fn test_refill_recovers_each_dimension_under_manual_clock() {
        let clock = Arc::new(ManualClock::new());
        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(2).with_clock(clock.clone()))
            .dimension("tokens", Throttle::per_second(10).with_clock(clock.clone()))
            .build();
        let full_call = [("requests", 2), ("tokens", 10)];

        // Spend both budgets completely; a further small call must then fail.
        assert!(multi.try_acquire_costs(&full_call));
        assert!(!multi.try_acquire_costs(&[("requests", 1), ("tokens", 1)]));

        // One simulated second later, both dimensions are full again.
        clock.advance(Duration::from_secs(1));
        assert!(multi.try_acquire_costs(&full_call));
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_costs_errors_and_names_the_overspent_dimension() {
        use crate::error::ThrottleError;

        let multi = MultiLimiter::builder()
            .dimension("requests", Throttle::per_second(100))
            .dimension("tokens", Throttle::per_second(1000))
            .build();

        // Asking for 2000 tokens from a 1000-token dimension can never succeed.
        let expected = ThrottleError::CostExceedsCapacity {
            cost: 2000,
            capacity: 1000,
        };
        let err = multi
            .acquire_costs(&[("requests", 1), ("tokens", 2000)])
            .await
            .unwrap_err();
        assert_eq!(err, expected);
    }
}