1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

/// Allocation tracker with total memory limit.
#[derive(Debug, Clone)]
pub struct AllocTracker {
    inner: Arc<AllocTrackerInner>,
}

/// State shared by an [`AllocTracker`], its clones, and the handles it hands out.
#[derive(Debug)]
struct AllocTrackerInner {
    /// Bytes still available under the limit. Allocations and limit shrinks subtract from it;
    /// releasing a handle and expanding the limit add to it.
    bytes_left: std::sync::atomic::AtomicUsize,
}

impl AllocTracker {
    /// Creates a memory allocation tracker with allowed allocation limit.
    pub fn with_limit(bytes_left: usize) -> Self {
        Self {
            inner: Arc::new(AllocTrackerInner {
                bytes_left: AtomicUsize::new(bytes_left),
            }),
        }
    }

    /// Records an allocation of `count` number of `T`, and returns handle of the record.
    ///
    /// Returns an error if the allocation exceeds the current limit.
    pub fn alloc<T>(&self, count: usize) -> Result<AllocHandle, crate::Error> {
        let bytes = count * std::mem::size_of::<T>();
        let result = self.inner.bytes_left.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |bytes_left| bytes_left.checked_sub(bytes),
        );

        match result {
            Ok(prev) => {
                tracing::trace!(bytes, left = prev - bytes, "Created allocation handle");
                Ok(AllocHandle {
                    bytes,
                    inner: Arc::clone(&self.inner),
                })
            }
            Err(left) => {
                tracing::trace!(bytes, left, "Allocation failed");
                Err(crate::Error::OutOfMemory(bytes))
            }
        }
    }

    /// Expands the current limit by `by_bytes` bytes.
    pub fn expand_limit(&self, by_bytes: usize) {
        self.inner.bytes_left.fetch_add(by_bytes, Ordering::Relaxed);
    }

    /// Shrinks the current limit by `by_bytes` bytes.
    ///
    /// Returns an error if the total amount of current allocation doesn't allow shrinking the
    /// limit.
    pub fn shrink_limit(&self, by_bytes: usize) -> Result<(), crate::Error> {
        let result = self.inner.bytes_left.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |bytes_left| bytes_left.checked_sub(by_bytes),
        );

        if result.is_ok() {
            Ok(())
        } else {
            Err(crate::Error::OutOfMemory(by_bytes))
        }
    }
}

/// Allocation handle.
///
/// Holds `bytes` bytes reserved from an [`AllocTracker`]. The reservation is returned to the
/// tracker when the handle is dropped. Discarding the handle right away therefore undoes the
/// allocation it records.
#[must_use = "dropping an `AllocHandle` immediately releases the recorded allocation"]
#[derive(Debug)]
pub struct AllocHandle {
    /// Number of bytes this handle holds against the tracker's limit.
    bytes: usize,
    /// Shared state of the tracker the bytes were taken from.
    inner: Arc<AllocTrackerInner>,
}

impl Drop for AllocHandle {
    /// Returns this handle's bytes to the tracker's remaining budget.
    fn drop(&mut self) {
        let bytes = self.bytes;
        let prev = self.inner.bytes_left.fetch_add(bytes, Ordering::Relaxed);
        // `fetch_add` wraps on overflow, so the logged value mirrors that with `wrapping_add`.
        // A plain `+` could panic in debug builds inside `drop`, which aborts if the thread is
        // already unwinding.
        tracing::trace!(bytes, left = prev.wrapping_add(bytes), "Released allocation handle");
    }
}

impl AllocHandle {
    /// Returns an [`AllocTracker`] that shares its budget with the tracker this handle was
    /// allocated from.
    pub fn tracker(&self) -> AllocTracker {
        let inner = Arc::clone(&self.inner);
        AllocTracker { inner }
    }
}