1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//! Provides HTTP proxy functionality.

use crate::config::LoadBalancerMode;
use crate::rand::{Choose, Lcg};
use crate::server::server::AppState;

use humphrey::http::headers::HeaderType;
use humphrey::http::proxy::proxy_request;
use humphrey::http::{Request, Response, StatusCode};

use std::net::ToSocketAddrs;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::Duration;

/// Represents a load balancer.
///
/// Holds the list of upstream targets together with the state needed to pick one of them per
/// request according to `mode`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LoadBalancer {
    /// The targets of the load balancer.
    ///
    /// Each target is resolved with `ToSocketAddrs` when a request is proxied, so it is
    /// presumably a `host:port` string — TODO confirm against config parsing.
    pub targets: Vec<String>,
    /// The algorithm used to choose a target.
    pub mode: LoadBalancerMode,
    /// Index into `targets` of the next target to hand out in round-robin mode.
    pub index: usize,
    /// The random number generator used to choose a target in random mode.
    pub lcg: Lcg,
}

impl LoadBalancer {
    /// Selects a target according to the load balancer mode.
    ///
    /// In round-robin mode the target at `index` is returned and `index` advances, wrapping
    /// back to the first target after the last one. In random mode a target is chosen using
    /// `lcg`.
    ///
    /// # Panics
    ///
    /// Panics if `targets` is empty.
    pub fn select_target(&mut self) -> String {
        assert!(
            !self.targets.is_empty(),
            "load balancer has no targets to select from"
        );

        match self.mode {
            LoadBalancerMode::RoundRobin => {
                // `index` is a public field, so it may have been set out of range (or the
                // target list may have shrunk); wrap it instead of panicking on the lookup.
                let target_index = self.index % self.targets.len();
                self.index = (target_index + 1) % self.targets.len();

                self.targets[target_index].clone()
            }
            LoadBalancerMode::Random => self
                .targets
                .choose(&mut self.lcg)
                .expect("targets was checked to be non-empty")
                .clone(),
        }
    }
}

/// Handles proxy requests.
pub fn proxy_handler(
    request: Request,
    state: Arc<AppState>,
    load_balancer: &EqMutex<LoadBalancer>,
    matches: &str,
) -> Response {
    let mut simplified_uri = request.uri.clone();

    for ch in matches.chars() {
        if ch != '*' {
            simplified_uri.remove(0);
        } else {
            break;
        }
    }

    if !simplified_uri.starts_with('/') {
        simplified_uri.insert(0, '/');
    }

    // Return error 403 if the address was blacklisted
    if state
        .config
        .blacklist
        .list
        .contains(&request.address.origin_addr)
    {
        state.logger.warn(&format!(
            "{}: Blacklisted IP attempted to request {}",
            request.address, request.uri
        ));
        Response::empty(StatusCode::Forbidden)
            .with_header(HeaderType::ContentType, "text/html")
            .with_bytes(b"<h1>403 Forbidden</h1>")
    } else {
        // Gets a load balancer target using the thread-safe `Mutex`
        let mut load_balancer_lock = load_balancer.lock().unwrap();
        let target = load_balancer_lock.select_target();
        drop(load_balancer_lock);

        let mut proxied_request = request.clone();
        proxied_request.uri = simplified_uri;

        let target_sock = target.to_socket_addrs().unwrap().next().unwrap();
        let response = proxy_request(&proxied_request, target_sock, Duration::from_secs(5));
        let status: u16 = response.status_code.into();
        let status_string: &str = response.status_code.into();

        state.logger.info(&format!(
            "{}: {} {} {}",
            request.address, status, status_string, request.uri
        ));

        response
    }
}

/// A `Mutex` which implements `PartialEq` for testing.
///
/// Two `EqMutex`es compare equal when the values they protect compare equal; the comparison
/// locks the mutexes for its duration.
#[derive(Debug)]
pub struct EqMutex<T> {
    mutex: Mutex<T>,
}

impl<T> EqMutex<T> {
    /// Locks the mutex, blocking the current thread until it can be acquired.
    ///
    /// # Errors
    ///
    /// Returns a `PoisonError` wrapping the guard if another thread panicked while holding
    /// the lock, exactly as `Mutex::lock` does.
    pub fn lock(&self) -> Result<MutexGuard<T>, PoisonError<MutexGuard<T>>> {
        self.mutex.lock()
    }

    /// Creates a new mutex protecting `data`.
    pub fn new(data: T) -> Self {
        Self {
            mutex: Mutex::new(data),
        }
    }
}

impl<T> PartialEq for EqMutex<T>
where
    T: PartialEq,
{
    /// Compares the protected values.
    ///
    /// Poisoning is ignored because equality only reads the data. Comparing a mutex with
    /// itself locks it once: `Mutex` is not reentrant, so a second `lock` from the same thread
    /// would deadlock or panic.
    ///
    /// NOTE(review): two threads concurrently evaluating `a == b` and `b == a` acquire the
    /// locks in opposite order and can deadlock; acceptable for test-only comparisons.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.lock().unwrap_or_else(PoisonError::into_inner);
        if std::ptr::eq(self, other) {
            // Still defer to `T`'s equality so non-reflexive values (e.g. NaN) behave as usual.
            return *lhs == *lhs;
        }
        let rhs = other.lock().unwrap_or_else(PoisonError::into_inner);
        *lhs == *rhs
    }
}