1use std::fmt;
20
21use reqwest::StatusCode;
22
23use serde::{
24 de::{Error, MapAccess, Visitor},
25 Deserialize, Deserializer, Serialize, Serializer,
26};
27
28use serde_json;
29
30use crate::{
31 error::EsError,
32 json::{FieldBased, NoOuter, ShouldSkip},
33 units::Duration,
34 Client, EsResponse,
35};
36
37use super::{
38 common::{OptionVal, Options, VersionType},
39 ShardCountResult,
40};
41
/// The kind of operation a bulk `Action` performs, or that a bulk `ActionResult`
/// reports.
///
/// Rendered (and serialized) as its lowercase name: `"index"`, `"create"`,
/// `"delete"` or `"update"`.
///
/// Derives `PartialEq`/`Eq` so callers can compare e.g. `result.action == ActionType::Index`
/// instead of matching, and `Copy`/`Hash` since it is a plain fieldless tag.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ActionType {
    Index,
    Create,
    Delete,
    Update,
}
50
51impl Serialize for ActionType {
52 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
53 where
54 S: Serializer,
55 {
56 self.to_string().serialize(serializer)
57 }
58}
59
60impl ToString for ActionType {
61 fn to_string(&self) -> String {
62 match *self {
63 ActionType::Index => "index",
64 ActionType::Create => "create",
65 ActionType::Delete => "delete",
66 ActionType::Update => "update",
67 }
68 .to_owned()
69 }
70}
71
72#[derive(Debug, Default, Serialize)]
73pub struct ActionOptions {
74 #[serde(rename = "_index", skip_serializing_if = "ShouldSkip::should_skip")]
75 index: Option<String>,
76 #[serde(rename = "_type", skip_serializing_if = "ShouldSkip::should_skip")]
77 doc_type: Option<String>,
78 #[serde(rename = "_id", skip_serializing_if = "ShouldSkip::should_skip")]
79 id: Option<String>,
80 #[serde(rename = "_version", skip_serializing_if = "ShouldSkip::should_skip")]
81 version: Option<u64>,
82 #[serde(
83 rename = "_version_type",
84 skip_serializing_if = "ShouldSkip::should_skip"
85 )]
86 version_type: Option<VersionType>,
87 #[serde(rename = "_routing", skip_serializing_if = "ShouldSkip::should_skip")]
88 routing: Option<String>,
89 #[serde(rename = "_parent", skip_serializing_if = "ShouldSkip::should_skip")]
90 parent: Option<String>,
91 #[serde(rename = "_timestamp", skip_serializing_if = "ShouldSkip::should_skip")]
92 timestamp: Option<String>,
93 #[serde(rename = "_ttl", skip_serializing_if = "ShouldSkip::should_skip")]
94 ttl: Option<Duration>,
95 #[serde(
96 rename = "_retry_on_conflict",
97 skip_serializing_if = "ShouldSkip::should_skip"
98 )]
99 retry_on_conflict: Option<u64>,
100}
101
102#[derive(Debug, Serialize)]
103pub struct Action<X>(FieldBased<ActionType, ActionOptions, NoOuter>, Option<X>);
104
105impl<S> Action<S>
106where
107 S: Serialize,
108{
109 pub fn index(document: S) -> Self {
114 Action(
115 FieldBased::new(ActionType::Index, Default::default(), NoOuter),
116 Some(document),
117 )
118 }
119
120 pub fn create(document: S) -> Self {
122 Action(
123 FieldBased::new(ActionType::Create, Default::default(), NoOuter),
124 Some(document),
125 )
126 }
127
128 fn add(&self, actstr: &mut String) -> Result<(), EsError> {
130 let command_str = serde_json::to_string(&self.0)?;
131
132 actstr.push_str(&command_str);
133 actstr.push_str("\n");
134
135 if let Some(ref source) = self.1 {
136 let payload_str = serde_json::to_string(source)?;
137 actstr.push_str(&payload_str);
138 actstr.push_str("\n");
139 }
140 Ok(())
141 }
142}
143
144impl<S> Action<S> {
145 pub fn delete<A: Into<String>>(id: A) -> Self {
156 Action(
157 FieldBased::new(
158 ActionType::Delete,
159 ActionOptions {
160 id: Some(id.into()),
161 ..Default::default()
162 },
163 NoOuter,
164 ),
165 None,
166 )
167 }
168
169 add_inner_field!(with_index, index, String);
172 add_inner_field!(with_doc_type, doc_type, String);
173 add_inner_field!(with_id, id, String);
174 add_inner_field!(with_version, version, u64);
175 add_inner_field!(with_version_type, version_type, VersionType);
176 add_inner_field!(with_routing, routing, String);
177 add_inner_field!(with_parent, parent, String);
178 add_inner_field!(with_timestamp, timestamp, String);
179 add_inner_field!(with_ttl, ttl, Duration);
180 add_inner_field!(with_retry_on_conflict, retry_on_conflict, u64);
181}
182
183#[derive(Debug)]
184pub struct BulkOperation<'a, 'b, S: 'b> {
185 client: &'a mut Client,
186 index: Option<&'b str>,
187 doc_type: Option<&'b str>,
188 actions: &'b [Action<S>],
189 options: Options<'b>,
190}
191
192impl<'a, 'b, S> BulkOperation<'a, 'b, S>
193where
194 S: Serialize,
195{
196 pub fn new(client: &'a mut Client, actions: &'b [Action<S>]) -> Self {
197 BulkOperation {
198 client,
199 index: None,
200 doc_type: None,
201 actions,
202 options: Options::default(),
203 }
204 }
205
206 pub fn with_index(&'b mut self, index: &'b str) -> &'b mut Self {
207 self.index = Some(index);
208 self
209 }
210
211 pub fn with_doc_type(&'b mut self, doc_type: &'b str) -> &'b mut Self {
212 self.doc_type = Some(doc_type);
213 self
214 }
215
216 add_option!(with_consistency, "consistency");
217 add_option!(with_refresh, "refresh");
218
219 fn format_url(&self) -> String {
220 let mut url = String::new();
221 url.push_str("/");
222 if let Some(index) = self.index {
223 url.push_str(index);
224 url.push_str("/");
225 }
226 if let Some(doc_type) = self.doc_type {
227 url.push_str(doc_type);
228 url.push_str("/");
229 }
230 url.push_str("_bulk");
231 url.push_str(&self.options.to_string());
232 url
233 }
234
235 fn format_actions(&self) -> String {
236 let mut actstr = String::new();
237 for action in self.actions {
238 action.add(&mut actstr).unwrap();
239 }
240 actstr
241 }
242
243 pub fn send(&self) -> Result<BulkResult, EsError> {
244 let response = self.client.do_es_op(&self.format_url(), |url| {
252 self.client
253 .http_client
254 .post(url)
255 .body(self.format_actions())
256 })?;
257
258 match response.status_code() {
259 StatusCode::OK => Ok(response.read_response()?),
260 status_code => Err(EsError::EsError(format!(
261 "Unexpected status: {}",
262 status_code
263 ))),
264 }
265 }
266}
267
268impl Client {
269 pub fn bulk<'a, 'b, S>(&'a mut self, actions: &'b [Action<S>]) -> BulkOperation<'a, 'b, S>
273 where
274 S: Serialize,
275 {
276 BulkOperation::new(self, actions)
277 }
278}
279
280#[derive(Debug)]
282pub struct ActionResult {
283 pub action: ActionType,
284 pub inner: ActionResultInner,
285}
286
287impl<'de> Deserialize<'de> for ActionResult {
288 fn deserialize<D>(deserializer: D) -> Result<ActionResult, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 struct ActionResultVisitor;
293
294 impl<'vde> Visitor<'vde> for ActionResultVisitor {
295 type Value = ActionResult;
296
297 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
298 formatter.write_str("an ActionResult")
299 }
300
301 fn visit_map<V>(self, mut visitor: V) -> Result<ActionResult, V::Error>
302 where
303 V: MapAccess<'vde>,
304 {
305 let visited: Option<(String, ActionResultInner)> = visitor.next_entry()?;
306 let (key, value) = match visited {
307 Some((key, value)) => (key, value),
308 None => return Err(V::Error::custom("expecting at least one field")),
309 };
310
311 let result = ActionResult {
312 action: match key.as_ref() {
313 "index" => ActionType::Index,
314 "create" => ActionType::Create,
315 "delete" => ActionType::Delete,
316 "update" => ActionType::Update,
317 _ => return Err(V::Error::custom(format!("Unrecognised key: {}", key))),
318 },
319 inner: value,
320 };
321
322 Ok(result)
323 }
324 }
325
326 deserializer.deserialize_any(ActionResultVisitor)
327 }
328}
329
330#[derive(Debug, serde::Deserialize)]
331pub struct ActionResultInner {
332 #[serde(rename = "_index")]
333 pub index: String,
334 #[serde(rename = "_type")]
335 pub doc_type: String,
336 #[serde(rename = "_version")]
337 pub version: u64,
338 pub status: u64,
339 #[serde(rename = "_shards")]
340 pub shards: ShardCountResult,
341 pub found: Option<bool>,
342}
343
344#[derive(Debug, serde::Deserialize)]
346pub struct BulkResult {
347 pub errors: bool,
348 pub items: Vec<ActionResult>,
349 pub took: u64,
350}
351
#[cfg(test)]
pub mod tests {
    use crate::tests::{clean_db, make_client, TestDocument};

    use super::{Action, ActionType};

    /// Indexes nine documents in one bulk request through the client from
    /// `make_client`. Checks that there were no errors, that one item came back
    /// per action, and that every item is an `index` result.
    #[test]
    fn test_bulk() {
        let index_name = "test_bulk";
        let mut client = make_client();

        clean_db(&mut client, index_name);

        let actions: Vec<Action<TestDocument>> = (1..10)
            .map(|i| {
                let doc = TestDocument::new()
                    .with_str_field("bulk_doc")
                    .with_int_field(i);
                Action::index(doc)
            })
            .collect();

        let result = client
            .bulk(&actions)
            .with_index(index_name)
            .with_doc_type("bulk_type")
            .send()
            .unwrap();

        assert!(!result.errors, "bulk request reported errors: {:?}", result.items);
        assert_eq!(actions.len(), result.items.len());
        assert!(result
            .items
            .iter()
            .all(|item| matches!(item.action, ActionType::Index)));
    }
}