1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
extern crate num_traits;

use std::cmp::Ordering;
use std::ops::Add;
use std::ops::Sub;
use num_traits::Bounded;

/**
 * Marker for a `Tree` that does not store the user data: `Tree::new_with_user_data_ref()` leaves
 * the tree's `user_data` empty and its `find_nearest()` takes the user data as an argument.
 */
pub struct UserDataByRef;
/**
 * Marker for a `Tree` that stores the user data itself: `Tree::new_with_user_data_owned()` moves
 * the user data into the tree and its `find_nearest()` reads it back from there.
 */
pub struct UserDataOwned;

/**
 * Implemented by items that can be stored in a `Tree`: it defines how far apart two items are.
 */
pub trait MetricSpace {
    /** Extra context handed to every `distance()` call (use `()` when none is needed). */
    type UserData;
    /**
     * Type of the value returned by `distance()`. The tree copies and compares these values,
     * adds and subtracts them, and uses `Bounded::max_value()` as the "no candidate yet" value.
     */
    type Distance: Copy + PartialOrd + Bounded + Add<Output=Self::Distance> + Sub<Output=Self::Distance>;

    /**
     * This function must return the distance between two items, and that distance must satisfy
     * the triangle inequality.
     * Specifically, it can't return squared distance (you must use sqrt if you use Euclidean distance)
     *
     * @param other     The item to measure the distance to.
     * @param user_data Whatever you want. Passed from new_with_user_data_owned() or
     *                  new_with_user_data_ref(), and from the find_nearest() call when searching.
     * @return          Distance between `self` and `other`.
     */
    fn distance(&self, other: &Self, user_data: &Self::UserData) -> Self::Distance;
}

/**
 * Accumulator that is offered search candidates one at a time.
 * `T` is the distance type of the candidates.
 */
pub trait BestCandidate<T> {
    /** Creates an accumulator that has not been offered any candidate yet. */
    fn new() -> Self;
    /**
     * Offers one candidate to the accumulator.
     *
     * @param distance         Distance of the candidate.
     * @param candidate_index  Index identifying the candidate.
     */
    fn consider(&mut self, distance: T, candidate_index: usize);
}

/* `Tmp` doubles as the search accumulator: it remembers the closest candidate offered so far. */
impl<Item: MetricSpace> BestCandidate<<Item as MetricSpace>::Distance> for Tmp<Item>  {
    /**
     * Starts at the largest representable distance (with index 0 as a placeholder),
     * so that any candidate with a smaller distance replaces it.
     */
    fn new() -> Self {
        let farthest = <Item::Distance as Bounded>::max_value();
        Tmp {
            idx: 0,
            distance: farthest,
        }
    }

    /**
     * Keeps the candidate only when it is strictly closer than the one currently held.
     */
    #[inline]
    fn consider(&mut self, distance: Item::Distance, candidate_index: usize) {
        // Guard clause: anything that is not strictly closer leaves the current best untouched.
        if !(distance < self.distance) {
            return;
        }
        self.idx = candidate_index;
        self.distance = distance;
    }
}

/**
 * One node of the vantage point tree. A node owns a copy of one item (its vantage point) and
 * up to two subtrees holding the remaining items, split by their distance to that item.
 */
struct Node<Item: MetricSpace + Copy> {
    // Subtree built from the closer half of the remaining items (None when there are none)
    near: Option<Box<Node<Item>>>,
    // Subtree built from the farther half of the remaining items (None when there are none)
    far: Option<Box<Node<Item>>>,
    vantage_point: Item, // Copy of the item (value) represented by the current node
    radius: Item::Distance,    // Distance from `vantage_point` to the closest item of the `far` subtree (max value in leaf nodes)
    idx: usize,             // Index of the `vantage_point` in the original items array
}

/**
 * Vantage point tree for nearest neighbor search.
 *
 * `Ownership` is one of the marker types `UserDataOwned` or `UserDataByRef` and selects which
 * constructor and which `find_nearest()` signature are available.
 */
pub struct Tree<Item: MetricSpace + Copy, Ownership> {
    // Root node; always present, the tree has no representation for zero items
    root: Node<Item>,
    // `Some` when built with new_with_user_data_owned(), `None` when built with new_with_user_data_ref()
    user_data: Option<Item::UserData>,
    // Marker value matching the `Ownership` type parameter
    _ownership: Ownership,
}

/* Temporary object used to reorder/track distances between items without modifying the original items array
   (also used during search to hold the index and distance of the best candidate found so far).
*/
struct Tmp<Item: MetricSpace> {
    // Distance of the item at `idx` from the current vantage point (or from the needle during search)
    distance: Item::Distance,
    // Index of the item in the original items array
    idx: usize,
}

impl<Item: MetricSpace<UserData = ()> + Copy> Tree<Item, UserDataOwned> {

    /**
     * Create a Vantage Point tree for items whose distance function needs no user data.
     * Shorthand for `new_with_user_data_owned(items, ())`.
     *
     * @param  items        Array of items that will be searched.
     * @see Tree::new_with_user_data_owned
     */
    pub fn new(items: &[Item]) -> Tree<Item, UserDataOwned> {
        let no_user_data = ();
        Self::new_with_user_data_owned(items, no_user_data)
    }
}

impl<T, Item: MetricSpace<UserData = T> + Copy> Tree<Item, UserDataOwned> {
    /**
     * Searches the tree for the item closest to `needle`, using the user data stored in the tree.
     * The needle can be any item; it does not have to be one of the items the tree was built from.
     *
     * @param  needle       The query.
     * @return              *Index* (in the items array given to the constructor) of the nearest item
     *                      found, and the distance between the needle and that item
     */
    pub fn find_nearest(&self, needle: &Item) -> (usize, Item::Distance) {
        // The owned constructor always stores `Some(user_data)`.
        let stored_user_data = self.user_data.as_ref().unwrap();
        self.find_nearest_with_user_data(needle, stored_user_data)
    }
}

impl<Item: MetricSpace + Copy, Ownership> Tree<Item, Ownership> {
    /**
     * Recomputes the distance of every entry from `vantage_point` and sorts the entries
     * by that distance, closest first.
     *
     * @param vantage_point  Item the distances are measured from.
     * @param indexes        Entries to update and reorder in place; each `idx` refers into `items`.
     * @param items          The original items array.
     * @param user_data      Passed down to item.distance()
     */
    fn sort_indexes_by_distance(vantage_point: Item, indexes: &mut [Tmp<Item>], items: &[Item], user_data: &Item::UserData) {
        for entry in indexes.iter_mut() {
            entry.distance = vantage_point.distance(&items[entry.idx], user_data);
        }
        // The comparator has to report equal distances as `Equal`: answering `Greater` for ties
        // in both directions is not a total order, and `sort_by` may panic on such comparators.
        // Incomparable distances (e.g. NaN) are treated as equal so that they cannot panic here.
        indexes.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
    }

    /**
     * Recursively builds the subtree for the items referenced by `indexes`.
     *
     * The first entry becomes the vantage point of the node. The remaining entries are sorted by
     * their distance to it and split in the middle: the closer half forms the `near` subtree and
     * the farther half the `far` subtree.
     *
     * @param indexes    Entries of the items that belong to this subtree (reordered in place).
     * @param items      The original items array.
     * @param user_data  Passed down to item.distance()
     * @return           The subtree, or `None` when `indexes` is empty.
     */
    fn create_node(indexes: &mut [Tmp<Item>], items: &[Item], user_data: &Item::UserData) -> Option<Node<Item>> {
        if indexes.is_empty() {
            return None;
        }

        let ref_idx = indexes[0].idx;

        // A single item is a leaf: there is nothing to split, so the radius is left at the maximum.
        if indexes.len() == 1 {
            return Some(Node{
                near: None, far: None,
                vantage_point: items[ref_idx],
                idx: ref_idx,
                radius: <Item::Distance as Bounded>::max_value(),
            });
        }

        // Removes the `ref_idx` item from remaining items, because it's included in the current node
        let rest = &mut indexes[1..];

        Self::sort_indexes_by_distance(items[ref_idx], rest, items, user_data);

        // Remaining items are split by the median distance
        let half_idx = rest.len()/2;

        // `rest` has at least one entry here, so `far_indexes` is never empty.
        let (near_indexes, far_indexes) = rest.split_at_mut(half_idx);

        // The radius is the distance of the closest item that ended up in the `far` half.
        let radius = far_indexes[0].distance;
        let near = Self::create_node(near_indexes, items, user_data).map(Box::new);
        let far = Self::create_node(far_indexes, items, user_data).map(Box::new);

        Some(Node{
            vantage_point: items[ref_idx],
            idx: ref_idx,
            radius: radius,
            near: near,
            far: far,
        })
    }
}

impl<Item: MetricSpace + Copy> Tree<Item, UserDataOwned> {
    /**
     * Create a Vantage Point tree for fast nearest neighbor search. The tree takes ownership
     * of `user_data` and keeps it for later searches.
     *
     * Panics if `items` is empty.
     *
     * @param  items        Array of items that will be searched.
     * @param  user_data    Any object; it is passed down to item.distance() and stored in the tree
     */
    pub fn new_with_user_data_owned(items: &[Item], user_data: Item::UserData) -> Tree<Item, UserDataOwned> {
        // Build the nodes first: that only borrows `user_data`, which is then moved into the tree.
        let root = Self::create_root_node(items, &user_data);
        Tree {
            root,
            user_data: Some(user_data),
            _ownership: UserDataOwned,
        }
    }
}

impl<Item: MetricSpace + Copy> Tree<Item, UserDataByRef> {
    /**
     * Create a Vantage Point tree for fast nearest neighbor search without storing the user data.
     * The user data is only borrowed while the tree is built and has to be supplied again to
     * `find_nearest()`.
     *
     * Panics if `items` is empty.
     *
     * @param  items        Array of items that will be searched.
     * @param  user_data    Reference to any object that is passed down to item.distance()
     */
    pub fn new_with_user_data_ref(items: &[Item], user_data: &Item::UserData) -> Tree<Item, UserDataByRef> {
        let root = Self::create_root_node(items, user_data);
        Tree {
            root,
            user_data: None,
            _ownership: UserDataByRef,
        }
    }

    /**
     * Finds the item closest to the given needle (that can be any item).
     *
     * @param  needle       The query.
     * @param  user_data    Reference to any object that is passed down to item.distance()
     * @return              Index (in the items array given to the constructor) of the nearest item
     *                      found, and the distance between the needle and that item
     */
    pub fn find_nearest(&self, needle: &Item, user_data: &Item::UserData) -> (usize, Item::Distance) {
        let nearest = self.find_nearest_with_user_data(needle, user_data);
        nearest
    }
}

use std::fmt::{Debug,Formatter,Error};
/* Formats the tree as a Graphviz "digraph": the header, then the edges written by the root node. */
impl<Item: Debug + Copy + MetricSpace, Ownership> Debug for Tree<Item, Ownership> {
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        let root = &self.root;
        write!(f, "digraph \"vp tree.dot\" {{\n{:?}}}", root)
    }
}

/* Writes one `"parent" -> "child"` line per edge of the subtree rooted at this node:
   first the edge to the `near` child followed by that child's subtree, then the same for `far`.
   A node without children writes nothing. */
impl<Item: Debug + Copy + MetricSpace> Debug for Node<Item> {
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        if let Some(ref near) = self.near {
            write!(f, "\"{:?}\" -> \"{:?}\"\n", self.vantage_point, near.vantage_point)?;
            near.fmt(f)?;
        }
        if let Some(ref far) = self.far {
            write!(f, "\"{:?}\" -> \"{:?}\"\n", self.vantage_point, far.vantage_point)?;
            far.fmt(f)?;
        }
        Ok(())
    }
}

impl<Item: MetricSpace + Copy, Ownership> Tree<Item, Ownership> {
    /**
     * Builds the whole tree for `items` and returns its root node.
     *
     * Panics if `items` is empty, because the tree always has a root node.
     *
     * @param  items        Array of items that will be searched.
     * @param  user_data    Passed down to item.distance()
     */
    fn create_root_node(items: &[Item], user_data: &Item::UserData) -> Node<Item> {
        // One entry per item; the distances are placeholders until a vantage point is chosen.
        let mut indexes: Vec<Tmp<Item>> = (0..items.len()).map(|i| Tmp{
            idx:i, distance: <Item::Distance as Bounded>::max_value(),
        }).collect();

        Self::create_node(&mut indexes[..], items, user_data)
            .expect("a vantage point tree cannot be built from an empty items array")
    }

    /**
     * Searches the subtree rooted at `node` and updates `best_candidate` whenever an item
     * closer to `needle` is found.
     *
     * @param node            Root of the subtree to search.
     * @param needle          The query.
     * @param best_candidate  Closest item found so far (index and distance).
     * @param user_data       Passed down to item.distance()
     */
    fn search_node(node: &Node<Item>, needle: &Item, best_candidate: &mut Tmp<Item>, user_data: &Item::UserData) {
        let distance = needle.distance(&node.vantage_point, user_data);

        best_candidate.consider(distance, node.idx);

        // Recurse towards most likely candidate first to narrow best candidate's distance as soon as possible
        if distance < node.radius {
            if let Some(ref near) = node.near {
                Self::search_node(near, needle, best_candidate, user_data);
            }
            // The best node (final answer) may be just outside the radius, but not farther than
            // the best distance we know so far. The search_node above should have narrowed
            // best_candidate.distance, so this path is rarely taken.
            if let Some(ref far) = node.far {
                // The subtraction cannot go below zero: the best distance is at most `distance`
                // (it was just considered), which is below the radius in this branch.
                if distance >= node.radius - best_candidate.distance {
                    Self::search_node(far, needle, best_candidate, user_data);
                }
            }
        } else {
            if let Some(ref far) = node.far {
                Self::search_node(far, needle, best_candidate, user_data);
            }
            if let Some(ref near) = node.near {
                // Same test as `distance <= radius + best`, rearranged so that it cannot overflow
                // integer distance types: `distance` is not below the radius in this branch.
                if distance - node.radius <= best_candidate.distance {
                    Self::search_node(near, needle, best_candidate, user_data);
                }
            }
        }
    }

    /**
     * Finds the item closest to `needle`, measuring distances with the given user data.
     *
     * @param  needle       The query.
     * @param  user_data    Passed down to item.distance()
     * @return              Index of the nearest item found and the distance from the nearest item
     */
    fn find_nearest_with_user_data(&self, needle: &Item, user_data: &Item::UserData) -> (usize, Item::Distance) {
        let mut best_candidate = Tmp::new();
        Self::search_node(&self.root, needle, &mut best_candidate, user_data);

        (best_candidate.idx, best_candidate.distance)
    }
}

// Test

// Test item: a point on the real line, compared by absolute difference, no user data.
#[cfg(test)]
#[derive(Copy, Clone)]
struct Foo(f32);

#[cfg(test)]
impl MetricSpace for Foo {
    type Distance = f32;
    type UserData = ();
    fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
        let Foo(here) = *self;
        let Foo(there) = *other;
        (here - there).abs()
    }
}

// Test item: an integer point whose distance function checks that the user data arrives intact.
#[cfg(test)]
#[derive(Copy, Clone)]
struct Bar(i32);

#[cfg(test)]
impl MetricSpace for Bar {
    type UserData = usize;
    type Distance = u32;

    fn distance(&self, other: &Self, user_data: &Self::UserData) -> Self::Distance {
        // Every call must receive the magic value the tests pass in.
        assert_eq!(12345, *user_data);

        let delta = self.0 - other.0;
        delta.abs() as u32
    }
}

#[test]
fn test_without_user_data() {
    let foos = [Foo(1.0), Foo(1.5), Foo(2.0)];
    let vp = Tree::new(&foos);

    // (query value, expected (index into `foos`, distance to that item))
    let cases: [(f32, (usize, f32)); 5] = [
        (100.0, (2, 98.0)),
        (-100.0, (0, 101.0)),
        (1.5, (1, 0.0)),
        (1.5-0.125, (1, 0.125)),
        (2.0-0.125, (2, 0.125)),
    ];

    for &(query, expected) in cases.iter() {
        assert_eq!(expected, vp.find_nearest(&Foo(query)));
    }
}

#[test]
fn test_with_user_data() {
    let magic = 12345;
    let bars = [Bar(10), Bar(15), Bar(20)];

    // Tree that owns its user data: the public search needs no extra argument.
    let owned = Tree::new_with_user_data_owned(&bars, magic);
    assert_eq!((1, 0), owned.find_nearest(&Bar(15)));
    assert_eq!((1, 1), owned.find_nearest_with_user_data(&Bar(16), &magic));

    // Tree that only borrowed the user data: it has to be supplied again for every search.
    let borrowed = Tree::new_with_user_data_ref(&bars, &magic);
    assert_eq!((0, 1), borrowed.find_nearest(&Bar(9), &magic));
    assert_eq!((0, 1), borrowed.find_nearest_with_user_data(&Bar(9), &magic));
}