Skip to main content

auto_derive/
lib.rs

1// Copyright 2026 Mahmoud Harmouch.
2//
3// Licensed under the MIT license
4// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
5// option. This file may not be copied, modified, or distributed
6// except according to those terms.
7
8extern crate proc_macro;
9
10use quote::quote;
11use syn::{DeriveInput, parse_macro_input};
12
13#[proc_macro_derive(Auto)]
14pub fn derive_agent(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
15    let input = parse_macro_input!(input as DeriveInput);
16    let name = &input.ident;
17
18    let expanded = quote! {
19        impl Agent for #name {
20            fn new(persona: Cow<'static, str>, behavior: Cow<'static, str>) -> Self {
21                let mut agent = Self::default();
22                agent.agent.persona = persona;
23                agent.agent.behavior = behavior;
24                agent
25            }
26
27            fn update(&mut self, status: Status) {
28                self.agent.update(status);
29            }
30
31            fn behavior(&self) -> &std::borrow::Cow<'static, str> {
32                &self.agent.behavior
33            }
34
35            fn persona(&self) -> &std::borrow::Cow<'static, str> {
36                &self.agent.persona
37            }
38
39            fn status(&self) -> &Status {
40                &self.agent.status
41            }
42
43            fn memory(&self) -> &Vec<Message> {
44                &self.agent.memory
45            }
46
47            fn tools(&self) -> &Vec<Tool> {
48                &self.agent.tools
49            }
50
51            fn knowledge(&self) -> &Knowledge {
52                &self.agent.knowledge
53            }
54
55            fn planner(&self) -> Option<&Planner> {
56                self.agent.planner.as_ref()
57            }
58
59            fn profile(&self) -> &Persona {
60                &self.agent.profile
61            }
62
63            #[cfg(feature = "net")]
64            fn collaborators(&self) -> Vec<Collaborator> {
65                let mut all = Vec::new();
66                all.extend(self.agent.local_collaborators.values().cloned());
67                all.extend(self.agent.remote_collaborators.values().cloned());
68                all
69            }
70
71            fn reflection(&self) -> Option<&Reflection> {
72                self.agent.reflection.as_ref()
73            }
74
75            fn scheduler(&self) -> Option<&TaskScheduler> {
76                self.agent.scheduler.as_ref()
77            }
78
79            fn capabilities(&self) -> &std::collections::HashSet<Capability> {
80                &self.agent.capabilities
81            }
82
83            fn context(&self) -> &ContextManager {
84                &self.agent.context
85            }
86
87            fn tasks(&self) -> &Vec<Task> {
88                &self.agent.tasks
89            }
90
91            fn memory_mut(&mut self) -> &mut Vec<Message> {
92                &mut self.agent.memory
93            }
94
95            fn planner_mut(&mut self) -> Option<&mut Planner> {
96                self.agent.planner.as_mut()
97            }
98
99            fn context_mut(&mut self) -> &mut ContextManager {
100                &mut self.agent.context
101            }
102        }
103
104        impl Functions for #name {
105            fn get_agent(&self) -> &AgentGPT {
106                &self.agent
107            }
108        }
109
110        #[async_trait]
111        impl AsyncFunctions for #name {
112            async fn execute<'a>(
113                &'a mut self,
114                task: &'a mut Task,
115                execute: bool,
116                browse: bool,
117                max_tries: u64,
118            ) -> Result<()> {
119                <#name as Executor>::execute(self, task, execute, browse, max_tries).await
120            }
121
122            /// Saves a communication to long-term memory for the agent.
123            ///
124            /// # Arguments
125            ///
126            /// * `communication` - The communication to save, which contains the role and content.
127            ///
128            /// # Returns
129            ///
130            /// (`Result<()>`): Result indicating the success or failure of saving the communication.
131            ///
132            /// # Business Logic
133            ///
134            /// - This method uses the `save_long_term_memory` util function to save the communication into the agent's long-term memory.
135            /// - The communication is embedded and stored using the agent's unique ID as the namespace.
136            /// - It handles the embedding and metadata for the communication, ensuring it's stored correctly.
137            #[cfg(feature = "mem")]
138            async fn save_ltm(&mut self, message: Message) -> Result<()> {
139                save_long_term_memory(&mut self.client, self.agent.id.clone(), message).await
140            }
141
142            /// Retrieves all communications stored in the agent's long-term memory.
143            ///
144            /// # Returns
145            ///
146            /// (`Result<Vec<Message>>`): A result containing a vector of communications retrieved from the agent's long-term memory.
147            ///
148            /// # Business Logic
149            ///
150            /// - This method fetches the stored communications for the agent by interacting with the `load_long_term_memory` function.
151            /// - The function will return a list of communications that are indexed by the agent's unique ID.
152            /// - It handles the retrieval of the stored metadata and content for each communication.
153            #[cfg(feature = "mem")]
154            async fn get_ltm(&self) -> Result<Vec<Message>> {
155                load_long_term_memory(self.agent.id.clone()).await
156            }
157
158            /// Retrieves the concatenated context of all communications in the agent's long-term memory.
159            ///
160            /// # Returns
161            ///
162            /// (`String`): A string containing the concatenated role and content of all communications stored in the agent's long-term memory.
163            ///
164            /// # Business Logic
165            ///
166            /// - This method calls the `long_term_memory_context` function to generate a string representation of the agent's entire long-term memory.
167            /// - The context string is composed of each communication's role and content, joined by new lines.
168            /// - It provides a quick overview of the agent's memory in a human-readable format.
169            #[cfg(feature = "mem")]
170            async fn ltm_context(&self) -> String {
171                long_term_memory_context(self.agent.id.clone()).await
172            }
173
174            async fn generate(&mut self, request: &str) -> Result<String> {
175                match &mut self.client {
176                    #[cfg(feature = "gem")]
177                    ClientType::Gemini(gem_client) => {
178                        let parameters = ChatBuilder::default()
179                            .messages(vec![gems::messages::Message::User {
180                                content: Content::Text(request.to_string()),
181                                name: None,
182                            }])
183                            .build()?;
184
185                        let result = gem_client.chat().generate(parameters).await;
186                        Ok(result.unwrap_or_default())
187                    }
188
189                    #[cfg(feature = "oai")]
190                    ClientType::OpenAI(oai_client) => {
191                        let parameters = ChatCompletionParametersBuilder::default()
192                            .model(Gpt4Model::Gpt4O.to_string())
193                            .messages(vec![ChatMessage::User {
194                                content: ChatMessageContent::Text(request.to_string()),
195                                name: None,
196                            }])
197                            .response_format(ChatCompletionResponseFormat::Text)
198                            .build()?;
199
200                        let result = oai_client.chat().create(parameters).await?;
201                        let message = &result.choices[0].message;
202
203                        Ok(match message {
204                            ChatMessage::Assistant {
205                                content: Some(chat_content),
206                                ..
207                            } => chat_content.to_string(),
208                            ChatMessage::User { content, .. } => content.to_string(),
209                            ChatMessage::System { content, .. } => content.to_string(),
210                            ChatMessage::Developer { content, .. } => content.to_string(),
211                            ChatMessage::Tool { content, .. } => content.to_string(),
212                            _ => String::new(),
213                        })
214                    }
215
216                    #[cfg(feature = "cld")]
217                    ClientType::Anthropic(client) => {
218                        let body = CreateMessageParams::new(RequiredMessageParams {
219                            model: "claude-3-7-sonnet-latest".to_string(),
220                            messages: vec![AnthMessage::new_text(Role::User, request.to_string())],
221                            max_tokens: 1024,
222                        });
223
224                        let chat_response = client.create_message(Some(&body)).await?;
225                        Ok(chat_response
226                            .content
227                            .iter()
228                            .filter_map(|block| match block {
229                                ContentBlock::Text { text, .. } => Some(text.as_str()),
230                                _ => None,
231                            })
232                            .collect::<Vec<_>>()
233                            .join("\n"))
234                    }
235
236                    #[cfg(feature = "xai")]
237                    ClientType::Xai(xai_client) => {
238                        let messages = vec![XaiMessage::text("user", request)];
239
240                        let rb = ChatCompletionsRequestBuilder::new(
241                            xai_client.clone(),
242                            "grok-beta".into(),
243                            messages,
244                        )
245                        .temperature(0.0)
246                        .stream(false);
247
248                        let req = rb.clone().build()?;
249                        let chat = rb.create_chat_completion(req).await?;
250                        Ok(chat.choices[0].message.content.to_string())
251                    }
252
253                    #[cfg(feature = "co")]
254                    ClientType::Cohere(co_client) => {
255                        use cohere_rust::api::chat::ChatRequest;
256                        use cohere_rust::api::GenerateModel;
257
258                        let chat_request = ChatRequest {
259                            message: request,
260                            ..Default::default()
261                        };
262
263                        let mut receiver = match co_client.chat(&chat_request).await {
264                            Ok(rx) => rx,
265                            Err(e) => return Err(anyhow::anyhow!("Cohere API initialization failed: {}", e)),
266                        };
267                        let mut full_text = String::new();
268                        while let Some(res) = receiver.recv().await {
269                            match res {
270                                Ok(cohere_rust::api::chat::ChatStreamResponse::ChatTextGeneration { text, .. }) => {
271                                    full_text.push_str(&text);
272                                }
273                                Ok(_) => {}
274                                // Err(e) => return Err(anyhow!("Cohere chat error: {:?}", e)),
275                                Err(_) => {},
276                            }
277                        }
278                        Ok(full_text)
279                    }
280
281                    #[allow(unreachable_patterns)]
282                    _ => {
283                        return Err(anyhow!(
284                            "No valid AI client configured. Enable `co`, `gem`, `oai`, `cld`, or `xai` feature."
285                        ));
286                    }
287                }
288            }
289
290            async fn imagen(&mut self, request: &str) -> Result<Vec<u8>> {
291                match &mut self.client {
292                    #[cfg(feature = "gem")]
293                    ClientType::Gemini(gem_client) => {
294                        gem_client.set_model(Model::Imagen4);
295
296                        let input = gems::messages::Message::User {
297                            content: Content::Text(request.into()),
298                            name: None,
299                        };
300
301                        let params = ImageGenBuilder::default()
302                            .model(Model::Imagen4)
303                            .input(input)
304                            .build()?;
305
306                        let image_bytes = gem_client.images().generate(params).await;
307                        Ok(image_bytes.unwrap_or_default())
308                    }
309
310                    #[cfg(feature = "oai")]
311                    ClientType::OpenAI(oai_client) => {
312                        // TODO: Implement this
313                        Ok(Default::default())
314                    }
315
316                    #[cfg(feature = "cld")]
317                    ClientType::Anthropic(client) => {
318                        // TODO: Implement this
319                        Ok(Default::default())
320                    }
321
322                    #[cfg(feature = "xai")]
323                    ClientType::Xai(xai_client) => {
324                        // TODO: Implement this
325                        Ok(Default::default())
326                    }
327
328                    #[cfg(feature = "co")]
329                    ClientType::Cohere(_co_client) => {
330                        // Cohere does not support image generation
331                        Ok(Default::default())
332                    }
333
334                    #[allow(unreachable_patterns)]
335                    _ => {
336                        return Err(anyhow!(
337                            "No valid AI client configured. Enable `co`, `gem`, `oai`, `cld`, or `xai` feature."
338                        ));
339                    }
340                }
341            }
342
343            async fn stream(&mut self, request: &str) -> Result<ReqResponse> {
344                let request_owned = request.to_string();
345                match &mut self.client {
346                    #[cfg(feature = "gem")]
347                    ClientType::Gemini(gem_client) => {
348                        let parameters = StreamBuilder::default()
349                            .model(Model::Flash3Preview)
350                            .input(gems::messages::Message::User {
351                                content: Content::Text(request_owned.clone()),
352                                name: None,
353                            })
354                            .build()?;
355
356                        let resp = gem_client.stream().generate(parameters).await?;
357                        let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
358
359                        tokio::spawn(async move {
360                            let mut resp = resp;
361                            let mut buffer = String::new();
362
363                            while let Ok(Some(chunk)) = resp.chunk().await {
364                                if let Ok(text) = std::str::from_utf8(&chunk) {
365                                    buffer.push_str(text);
366                                    let mut parts: Vec<&str> =
367                                        buffer.split("\n\n").collect();
368                                    let new_buffer = if !buffer.ends_with("\n\n") {
369                                        parts.pop().unwrap_or("").to_string()
370                                    } else {
371                                        String::new()
372                                    };
373
374                                    for part in parts {
375                                        for line in part.lines() {
376                                            if let Some(data) =
377                                                line.strip_prefix("data: ")
378                                            {
379                                                let data = data.trim();
380                                                if data == "[DONE]" {
381                                                    continue;
382                                                }
383                                                if let Ok(json) =
384                                                    serde_json::from_str::<serde_json::Value>(data)
385                                                {
386                                                    if let Some(text) = json
387                                                        .get("candidates")
388                                                        .and_then(|c| c.get(0))
389                                                        .and_then(|c| c.get("content"))
390                                                        .and_then(|c| c.get("parts"))
391                                                        .and_then(|p| p.get(0))
392                                                        .and_then(|p| p.get("text"))
393                                                        .and_then(|t| t.as_str())
394                                                    {
395                                                        let _ = tx
396                                                            .send(text.to_string())
397                                                            .await;
398                                                    }
399                                                }
400                                            }
401                                        }
402                                    }
403
404                                    buffer = new_buffer;
405                                }
406                            }
407                        });
408
409                        Ok(ReqResponse(Some(rx)))
410                    }
411
412                    #[cfg(feature = "oai")]
413                    ClientType::OpenAI(oai_client) => {
414                        let oai_client = oai_client.clone();
415                        let request_owned = request_owned.clone();
416                        let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
417
418                        tokio::spawn(async move {
419                            use futures::StreamExt;
420
421                            let parameters =
422                                ChatCompletionParametersBuilder::default()
423                                    .model("gpt-5")
424                                    .messages(vec![ChatMessage::User {
425                                        content: ChatMessageContent::Text(
426                                            request_owned,
427                                        ),
428                                        name: None,
429                                    }])
430                                    .build()
431                                    .unwrap();
432
433                            if let Ok(mut stream) =
434                                oai_client.chat().create_stream(parameters).await
435                            {
436                                while let Some(response) = stream.next().await {
437                                    match response {
438                                        Ok(chat_response) => {
439                                            for choice in chat_response.choices {
440                                                let text_opt = match &choice.delta {
441                                                    openai_dive::v1::resources::chat::DeltaChatMessage::Assistant {
442                                                        content: Some(
443                                                            openai_dive::v1::resources::chat::ChatMessageContent::Text(text),
444                                                        ),
445                                                        ..
446                                                    } => Some(text.clone()),
447                                                    openai_dive::v1::resources::chat::DeltaChatMessage::Untagged {
448                                                        content: Some(
449                                                            openai_dive::v1::resources::chat::ChatMessageContent::Text(text),
450                                                        ),
451                                                        ..
452                                                    } => Some(text.clone()),
453                                                    _ => None,
454                                                };
455                                                if let Some(t) = text_opt {
456                                                    let _ = tx.send(t).await;
457                                                }
458                                            }
459                                        }
460                                        Err(_) => break,
461                                    }
462                                }
463                            }
464                        });
465
466                        Ok(ReqResponse(Some(rx)))
467                    }
468
469                    #[cfg(feature = "cld")]
470                    ClientType::Anthropic(client) => {
471                        let client = client.clone();
472                        let request_owned = request_owned.clone();
473                        let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
474
475                        tokio::spawn(async move {
476                            use futures::StreamExt;
477
478                            let body = CreateMessageParams::new(
479                                RequiredMessageParams {
480                                    model: "claude-opus-4-6".to_string(),
481                                    messages: vec![AnthMessage::new_text(
482                                        Role::User,
483                                        request_owned,
484                                    )],
485                                    max_tokens: 1024,
486                                },
487                            )
488                            .with_stream(true);
489
490                            if let Ok(mut stream) =
491                                client.create_message_streaming(&body).await
492                            {
493                                while let Some(event_result) = stream.next().await {
494                                    if let Ok(
495                                        anthropic_ai_sdk::types::message::StreamEvent::ContentBlockDelta {
496                                            delta,
497                                            ..
498                                        },
499                                    ) = event_result
500                                    {
501                                        if let anthropic_ai_sdk::types::message::ContentBlockDelta::TextDelta {
502                                            text,
503                                        } = delta
504                                        {
505                                            let _ = tx.send(text).await;
506                                        }
507                                    }
508                                }
509                            }
510                        });
511
512                        Ok(ReqResponse(Some(rx)))
513                    }
514
515                    #[cfg(feature = "xai")]
516                    ClientType::Xai(xai_client) => {
517                        use x_ai::traits::ClientConfig;
518
519                        let messages =
520                            vec![XaiMessage::text("user", request_owned)];
521                        let req = ChatCompletionsRequestBuilder::new(
522                            xai_client.clone(),
523                            "grok-4".into(),
524                            messages,
525                        )
526                        .stream(true)
527                        .build()?;
528
529                        let resp = x_ai::traits::ClientConfig::request(
530                            &*xai_client,
531                            reqwest::Method::POST,
532                            "chat/completions",
533                        )
534                        .map_err(|e| {
535                            anyhow::anyhow!("Failed to build xAI request: {}", e)
536                        })?
537                        .json(&req)
538                        .send()
539                        .await?;
540
541                        let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
542
543                        tokio::spawn(async move {
544                            let mut resp = resp;
545                            let mut buffer = String::new();
546
547                            while let Ok(Some(chunk)) = resp.chunk().await {
548                                if let Ok(text) = std::str::from_utf8(&chunk) {
549                                    buffer.push_str(text);
550                                    let mut parts: Vec<&str> =
551                                        buffer.split("\n\n").collect();
552                                    let new_buffer = if !buffer.ends_with("\n\n") {
553                                        parts.pop().unwrap_or("").to_string()
554                                    } else {
555                                        String::new()
556                                    };
557
558                                    for part in parts {
559                                        for line in part.lines() {
560                                            if let Some(data) =
561                                                line.strip_prefix("data: ")
562                                            {
563                                                let data = data.trim();
564                                                if data == "[DONE]" {
565                                                    continue;
566                                                }
567                                                if let Ok(json) =
568                                                    serde_json::from_str::<serde_json::Value>(data)
569                                                {
570                                                    if let Some(content) = json
571                                                        .get("choices")
572                                                        .and_then(|c| c.get(0))
573                                                        .and_then(|c| c.get("delta"))
574                                                        .and_then(|d| d.get("content"))
575                                                        .and_then(|c| c.as_str())
576                                                    {
577                                                        let _ = tx
578                                                            .send(content.to_string())
579                                                            .await;
580                                                    }
581                                                }
582                                            }
583                                        }
584                                    }
585
586                                    buffer = new_buffer;
587                                }
588                            }
589                        });
590
591                        Ok(ReqResponse(Some(rx)))
592                    }
593
594                    #[cfg(feature = "co")]
595                    ClientType::Cohere(co_client) => {
596                        let chat_request = cohere_rust::api::chat::ChatRequest {
597                            message: request,
598                            ..Default::default()
599                        };
600                        let co_result = co_client.chat(&chat_request).await;
601                        let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
602
603                        if let Ok(mut receiver) = co_result {
604                            tokio::spawn(async move {
605                                while let Some(res) = receiver.recv().await {
606                                    if let Ok(resp) = res {
607                                        if let cohere_rust::api::chat::ChatStreamResponse::ChatTextGeneration {
608                                            text, ..
609                                        } = resp
610                                        {
611                                            let _ = tx.send(text).await;
612                                        }
613                                    }
614                                }
615                            });
616                        }
617
618                        Ok(ReqResponse(Some(rx)))
619                    }
620
621                    #[allow(unreachable_patterns)]
622                    _ => {
623                        return Err(anyhow!(
624                            "No valid AI client configured. \
625                             Enable `co`, `gem`, `oai`, `cld`, or `xai` feature."
626                        ));
627                    }
628                }
629            }
630        }
631    };
632
633    proc_macro::TokenStream::from(expanded)
634}
635
636// Copyright 2026 Mahmoud Harmouch.
637//
638// Licensed under the MIT license
639// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
640// option. This file may not be copied, modified, or distributed
641// except according to those terms.