mysql/conn/local_infile.rs
1// Copyright (c) 2020 rust-mysql-simple contributors
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::{
10 fmt, io,
11 sync::{Arc, Mutex},
12};
13
14use crate::Conn;
15
/// Shared, lockable storage for a local-infile callback.
///
/// The callback receives the requested file name (as raw bytes) and a [`LocalInfile`]
/// writer into which it streams the file contents. It is kept behind `Arc<Mutex<_>>`:
/// clones of a handler share the same callback, and the `FnMut` can be called through a
/// shared handle by locking the mutex. `Send` is required so the callback can be sent
/// to another thread together with whatever owns it.
pub(crate) type LocalInfileInner =
    Arc<Mutex<dyn for<'a> FnMut(&'a [u8], &'a mut LocalInfile<'_>) -> io::Result<()> + Send>>;
18
/// Callback to handle requests for local files.
///
/// The callback receives the file name named in a `LOAD DATA LOCAL INFILE` statement and a
/// [`LocalInfile`] writer. Everything written to that writer is sent to the server as the
/// contents of the file. Consult the
/// [MySQL documentation](https://dev.mysql.com/doc/refman/5.7/en/load-data.html) for the
/// format of local infile data.
///
/// Cloning a handler is cheap: clones share the same underlying callback.
///
/// # Support
///
/// Note that some MySQL servers (for example, older versions) may not support or allow this
/// functionality. The example below treats server error 1148 as "not supported".
///
/// ```rust
/// # use std::io::Write;
/// # use mysql::{
/// #     Pool,
/// #     Opts,
/// #     OptsBuilder,
/// #     LocalInfileHandler,
/// #     from_row,
/// #     error::Error,
/// #     prelude::*,
/// # };
/// use mysql::prelude::Queryable;
/// # fn get_opts() -> Opts {
/// #     let url = if let Ok(url) = std::env::var("DATABASE_URL") {
/// #         let opts = Opts::from_url(&url).expect("DATABASE_URL invalid");
/// #         if opts.get_db_name().expect("a database name is required").is_empty() {
/// #             panic!("database name is empty");
/// #         }
/// #         url
/// #     } else {
/// #         "mysql://root:password@127.0.0.1:3307/mysql".to_string()
/// #     };
/// #     Opts::from_url(&*url).unwrap()
/// # }
/// # let opts = get_opts();
/// # let pool = Pool::new_manual(1, 1, opts).unwrap();
/// # let mut conn = pool.get_conn().unwrap();
/// # conn.query_drop("CREATE TEMPORARY TABLE mysql.Users (id INT, name TEXT, age INT, email TEXT)").unwrap();
/// # conn.exec_drop("INSERT INTO mysql.Users (id, name, age, email) VALUES (?, ?, ?, ?)",
/// #                (1, "John", 17, "foo@bar.baz")).unwrap();
/// conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a TEXT)").unwrap();
///
/// conn.set_local_infile_handler(Some(
///     LocalInfileHandler::new(|file_name, writer| {
///         writer.write_all(b"row1: file name is ")?;
///         writer.write_all(file_name)?;
///         writer.write_all(b"\n")?;
///
///         writer.write_all(b"row2: foobar\n")
///     })
/// ));
///
/// match conn.query_drop("LOAD DATA LOCAL INFILE 'file_name' INTO TABLE mysql.tbl") {
///     Ok(_) => (),
///     Err(Error::MySqlError(ref e)) if e.code == 1148 => {
///         // The server rejected LOCAL INFILE (error 1148): functionality not available.
///         return;
///     }
///     err => {
///         err.unwrap();
///     }
/// }
///
/// let result: Vec<String> = conn.query("SELECT * FROM mysql.tbl").unwrap();
/// assert_eq!(
///     result,
///     vec!["row1: file name is file_name".to_string(), "row2: foobar".to_string()],
/// );
/// ```
#[derive(Clone)]
pub struct LocalInfileHandler(pub(crate) LocalInfileInner);
89
90impl LocalInfileHandler {
91 pub fn new<F>(f: F) -> Self
92 where
93 F: for<'a> FnMut(&'a [u8], &'a mut LocalInfile<'_>) -> io::Result<()> + Send + 'static,
94 {
95 LocalInfileHandler(Arc::new(Mutex::new(f)))
96 }
97}
98
99impl PartialEq for LocalInfileHandler {
100 fn eq(&self, other: &LocalInfileHandler) -> bool {
101 (&*self.0 as *const _) == (&*other.0 as *const _)
102 }
103}
104
105impl Eq for LocalInfileHandler {}
106
107impl fmt::Debug for LocalInfileHandler {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
109 write!(f, "LocalInfileHandler(...)")
110 }
111}
112
/// Local in-file stream.
///
/// The callback will be passed a reference to this stream, which it
/// should use to write the contents of the requested file.
/// See the [`LocalInfileHandler`] documentation for an example.
#[derive(Debug)]
pub struct LocalInfile<'a> {
    /// Fixed-size staging buffer. The cursor position is the number of bytes pending,
    /// and the slice length bounds how much is staged before it is sent as a packet.
    buffer: io::Cursor<Box<[u8]>>,
    /// Connection that staged data is written to, one packet per flush.
    conn: &'a mut Conn,
}
122
123impl<'a> LocalInfile<'a> {
124 pub(crate) fn new(buffer: io::Cursor<Box<[u8]>>, conn: &'a mut Conn) -> Self {
125 Self { buffer, conn }
126 }
127}
128
129impl<'a> io::Write for LocalInfile<'a> {
130 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
131 let result = self.buffer.write(buf);
132 if result.is_ok() && self.buffer.position() as usize >= self.buffer.get_ref().len() {
133 self.flush()?;
134 }
135 result
136 }
137
138 fn flush(&mut self) -> io::Result<()> {
139 let n = self.buffer.position() as usize;
140 if n > 0 {
141 let mut range = &self.buffer.get_ref()[..n];
142 self.conn
143 .write_packet(&mut range)
144 .map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)))?;
145 }
146 self.buffer.set_position(0);
147 Ok(())
148 }
149}